diff --git a/.asf.yaml b/.asf.yaml index 22042b355b2fa..3935a525ff3c4 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -31,6 +31,8 @@ github: merge: false squash: true rebase: true + ghp_branch: master + ghp_path: /docs notifications: pullrequests: reviews@spark.apache.org diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 0bf7e57c364e4..2b459e4c73bbb 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -304,7 +304,7 @@ jobs: uses: actions/upload-artifact@v4 with: name: unit-tests-log-${{ matrix.modules }}-${{ matrix.comment }}-${{ matrix.java }}-${{ matrix.hadoop }}-${{ matrix.hive }} - path: "**/target/unit-tests.log" + path: "**/target/*.log" infra-image: name: "Base image build" @@ -723,7 +723,7 @@ jobs: # See 'ipython_genutils' in SPARK-38517 # See 'docutils<0.18.0' in SPARK-39421 python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ - ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ + ipython ipython_genutils sphinx_plotly_directive 'numpy==1.26.4' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml index 3ac1a0117e41b..f668d813ef26e 100644 --- a/.github/workflows/build_python_connect.yml +++ b/.github/workflows/build_python_connect.yml @@ -71,7 +71,7 @@ jobs: python packaging/connect/setup.py sdist cd dist pip install pyspark*connect-*.tar.gz - pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting + pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting 'plotly>=4.8' - name: Run tests env: SPARK_TESTING: 1 diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml new file mode 100644 index 0000000000000..8faeb0557fbfb --- /dev/null +++ b/.github/workflows/pages.yml @@ -0,0 +1,97 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +name: GitHub Pages deployment + +on: + push: + branches: + - master + +concurrency: + group: 'docs preview' + cancel-in-progress: false + +jobs: + docs: + name: Build and deploy documentation + runs-on: ubuntu-latest + permissions: + id-token: write + pages: write + environment: + name: github-pages # https://github.com/actions/deploy-pages/issues/271 + env: + SPARK_TESTING: 1 # Reduce some noise in the logs + RELEASE_VERSION: 'In-Progress' + steps: + - name: Checkout Spark repository + uses: actions/checkout@v4 + with: + repository: apache/spark + ref: 'master' + - name: Install Java 17 + uses: actions/setup-java@v4 + with: + distribution: zulu + java-version: 17 + - name: Install Python 3.9 + uses: actions/setup-python@v5 + with: + python-version: '3.9' + architecture: x64 + cache: 'pip' + - name: Install Python dependencies + run: | + pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ + ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow 'pandas==2.2.2' 'plotly>=4.8' 'docutils<0.18.0' \ + 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ + 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ + 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' + - name: Install Ruby for documentation generation + uses: ruby/setup-ruby@v1 + with: + ruby-version: '3.3' + bundler-cache: true + - name: Install Pandoc + run: | + sudo apt-get update -y + sudo apt-get install pandoc + - name: Install dependencies for documentation generation + run: | + cd docs + gem install bundler -v 2.4.22 -n /usr/local/bin + bundle install --retry=100 + - name: Run documentation build + run: | + sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml + sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml + sed -i".tmp3" "s/'facetFilters':.*$/'facetFilters': [\"version:$RELEASE_VERSION\"]/g" docs/_config.yml + sed -i".tmp4" 's/__version__: str = .*$/__version__: str = "'"$RELEASE_VERSION"'"/' python/pyspark/version.py + cd docs + SKIP_RDOC=1 bundle exec jekyll build + - name: Setup Pages + uses: actions/configure-pages@v5 + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 + with: + path: 'docs/_site' + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/test_report.yml b/.github/workflows/test_report.yml index c6225e6a1abe5..9ab69af42c818 100644 --- a/.github/workflows/test_report.yml +++ b/.github/workflows/test_report.yml @@ -30,14 +30,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Download test results to report - uses: dawidd6/action-download-artifact@09385b76de790122f4da9c82b17bccf858b9557c # pin@v2 + uses: dawidd6/action-download-artifact@bf251b5aa9c2f7eeb574a96ee720e24f801b7c11 # pin @v6 with: github_token: ${{ secrets.GITHUB_TOKEN }} workflow: ${{ github.event.workflow_run.workflow_id }} commit: ${{ github.event.workflow_run.head_commit.id }} workflow_conclusion: completed - name: Publish test report - uses: scacap/action-surefire-report@482f012643ed0560e23ef605a79e8e87ca081648 # pin@v1 + uses: scacap/action-surefire-report@a2911bd1a4412ec18dde2d93b1758b3e56d2a880 # pin @v1.8.0 with: check_name: Report test results 
github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/assembly/pom.xml b/assembly/pom.xml index e5628ce90fa90..01bd324efc118 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -123,7 +123,7 @@ com.google.guava @@ -200,7 +200,7 @@ cp - ${basedir}/../connector/connect/client/jvm/target/spark-connect-client-jvm_${scala.binary.version}-${version}.jar + ${basedir}/../connector/connect/client/jvm/target/spark-connect-client-jvm_${scala.binary.version}-${project.version}.jar ${basedir}/target/scala-${scala.binary.version}/jars/connect-repl @@ -339,6 +339,14 @@ + + + jjwt + + compile + + + diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index 5ed3048fb72b3..fb610a5d96f17 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -109,7 +109,7 @@ private static int lowercaseMatchLengthFrom( } // Compare the characters in the target and pattern strings. int matchLength = 0, codePointBuffer = -1, targetCodePoint, patternCodePoint; - while (targetIterator.hasNext() && patternIterator.hasNext()) { + while ((targetIterator.hasNext() || codePointBuffer != -1) && patternIterator.hasNext()) { if (codePointBuffer != -1) { targetCodePoint = codePointBuffer; codePointBuffer = -1; @@ -211,7 +211,7 @@ private static int lowercaseMatchLengthUntil( } // Compare the characters in the target and pattern strings. int matchLength = 0, codePointBuffer = -1, targetCodePoint, patternCodePoint; - while (targetIterator.hasNext() && patternIterator.hasNext()) { + while ((targetIterator.hasNext() || codePointBuffer != -1) && patternIterator.hasNext()) { if (codePointBuffer != -1) { targetCodePoint = codePointBuffer; codePointBuffer = -1; diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 5640a2468d02e..d5dbca7eb89bc 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -23,12 +23,14 @@ import java.util.function.Function; import java.util.function.BiFunction; import java.util.function.ToLongFunction; +import java.util.stream.Stream; +import com.ibm.icu.text.CollationKey; +import com.ibm.icu.text.Collator; import com.ibm.icu.text.RuleBasedCollator; import com.ibm.icu.text.StringSearch; import com.ibm.icu.util.ULocale; -import com.ibm.icu.text.CollationKey; -import com.ibm.icu.text.Collator; +import com.ibm.icu.util.VersionInfo; import org.apache.spark.SparkException; import org.apache.spark.unsafe.types.UTF8String; @@ -88,6 +90,17 @@ public Optional getVersion() { } } + public record CollationMeta( + String catalog, + String schema, + String collationName, + String language, + String country, + String icuVersion, + String padAttribute, + boolean accentSensitivity, + boolean caseSensitivity) { } + /** * Entry encapsulating all information about a collation. 
*/ @@ -342,6 +355,23 @@ private static int collationNameToId(String collationName) throws SparkException } protected abstract Collation buildCollation(); + + protected abstract CollationMeta buildCollationMeta(); + + static List listCollations() { + return Stream.concat( + CollationSpecUTF8.listCollations().stream(), + CollationSpecICU.listCollations().stream()).toList(); + } + + static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) { + CollationMeta collationSpecUTF8 = + CollationSpecUTF8.loadCollationMeta(collationIdentifier); + if (collationSpecUTF8 == null) { + return CollationSpecICU.loadCollationMeta(collationIdentifier); + } + return collationSpecUTF8; + } } private static class CollationSpecUTF8 extends CollationSpec { @@ -364,6 +394,9 @@ private enum CaseSensitivity { */ private static final int CASE_SENSITIVITY_MASK = 0b1; + private static final String UTF8_BINARY_COLLATION_NAME = "UTF8_BINARY"; + private static final String UTF8_LCASE_COLLATION_NAME = "UTF8_LCASE"; + private static final int UTF8_BINARY_COLLATION_ID = new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).collationId; private static final int UTF8_LCASE_COLLATION_ID = @@ -406,7 +439,7 @@ private static CollationSpecUTF8 fromCollationId(int collationId) { protected Collation buildCollation() { if (collationId == UTF8_BINARY_COLLATION_ID) { return new Collation( - "UTF8_BINARY", + UTF8_BINARY_COLLATION_NAME, PROVIDER_SPARK, null, UTF8String::binaryCompare, @@ -417,7 +450,7 @@ protected Collation buildCollation() { /* supportsLowercaseEquality = */ false); } else { return new Collation( - "UTF8_LCASE", + UTF8_LCASE_COLLATION_NAME, PROVIDER_SPARK, null, CollationAwareUTF8String::compareLowerCase, @@ -428,6 +461,52 @@ protected Collation buildCollation() { /* supportsLowercaseEquality = */ true); } } + + @Override + protected CollationMeta buildCollationMeta() { + if (collationId == UTF8_BINARY_COLLATION_ID) { + return new CollationMeta( + CATALOG, + SCHEMA, + UTF8_BINARY_COLLATION_NAME, + /* language = */ null, + /* country = */ null, + /* icuVersion = */ null, + COLLATION_PAD_ATTRIBUTE, + /* accentSensitivity = */ true, + /* caseSensitivity = */ true); + } else { + return new CollationMeta( + CATALOG, + SCHEMA, + UTF8_LCASE_COLLATION_NAME, + /* language = */ null, + /* country = */ null, + /* icuVersion = */ null, + COLLATION_PAD_ATTRIBUTE, + /* accentSensitivity = */ true, + /* caseSensitivity = */ false); + } + } + + static List listCollations() { + CollationIdentifier UTF8_BINARY_COLLATION_IDENT = + new CollationIdentifier(PROVIDER_SPARK, UTF8_BINARY_COLLATION_NAME, "1.0"); + CollationIdentifier UTF8_LCASE_COLLATION_IDENT = + new CollationIdentifier(PROVIDER_SPARK, UTF8_LCASE_COLLATION_NAME, "1.0"); + return Arrays.asList(UTF8_BINARY_COLLATION_IDENT, UTF8_LCASE_COLLATION_IDENT); + } + + static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) { + try { + int collationId = CollationSpecUTF8.collationNameToId( + collationIdentifier.name, collationIdentifier.name.toUpperCase()); + return CollationSpecUTF8.fromCollationId(collationId).buildCollationMeta(); + } catch (SparkException ignored) { + // ignore + return null; + } + } } private static class CollationSpecICU extends CollationSpec { @@ -684,6 +763,20 @@ protected Collation buildCollation() { /* supportsLowercaseEquality = */ false); } + @Override + protected CollationMeta buildCollationMeta() { + return new CollationMeta( + CATALOG, + SCHEMA, + collationName(), + ICULocaleMap.get(locale).getDisplayLanguage(), + 
ICULocaleMap.get(locale).getDisplayCountry(), + VersionInfo.ICU_VERSION.toString(), + COLLATION_PAD_ATTRIBUTE, + accentSensitivity == AccentSensitivity.AS, + caseSensitivity == CaseSensitivity.CS); + } + /** * Compute normalized collation name. Components of collation name are given in order: * - Locale name @@ -704,6 +797,37 @@ private String collationName() { } return builder.toString(); } + + private static List allCollationNames() { + List collationNames = new ArrayList<>(); + for (String locale: ICULocaleToId.keySet()) { + // CaseSensitivity.CS + AccentSensitivity.AS + collationNames.add(locale); + // CaseSensitivity.CS + AccentSensitivity.AI + collationNames.add(locale + "_AI"); + // CaseSensitivity.CI + AccentSensitivity.AS + collationNames.add(locale + "_CI"); + // CaseSensitivity.CI + AccentSensitivity.AI + collationNames.add(locale + "_CI_AI"); + } + return collationNames.stream().sorted().toList(); + } + + static List listCollations() { + return allCollationNames().stream().map(name -> + new CollationIdentifier(PROVIDER_ICU, name, VersionInfo.ICU_VERSION.toString())).toList(); + } + + static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) { + try { + int collationId = CollationSpecICU.collationNameToId( + collationIdentifier.name, collationIdentifier.name.toUpperCase()); + return CollationSpecICU.fromCollationId(collationId).buildCollationMeta(); + } catch (SparkException ignored) { + // ignore + return null; + } + } } /** @@ -730,9 +854,12 @@ public CollationIdentifier identifier() { } } + public static final String CATALOG = "SYSTEM"; + public static final String SCHEMA = "BUILTIN"; public static final String PROVIDER_SPARK = "spark"; public static final String PROVIDER_ICU = "icu"; public static final List SUPPORTED_PROVIDERS = List.of(PROVIDER_SPARK, PROVIDER_ICU); + public static final String COLLATION_PAD_ATTRIBUTE = "NO_PAD"; public static final int UTF8_BINARY_COLLATION_ID = Collation.CollationSpecUTF8.UTF8_BINARY_COLLATION_ID; @@ -794,6 +921,18 @@ public static int collationNameToId(String collationName) throws SparkException return Collation.CollationSpec.collationNameToId(collationName); } + /** + * Returns whether the ICU collation is not Case Sensitive Accent Insensitive + * for the given collation id. + * This method is used in expressions which do not support CS_AI collations. 
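A hedged sketch (not part of this patch) of how a Catalyst expression that cannot handle CS_AI collations might use the helper documented above; the wrapper method and error message here are illustrative assumptions only:

    import org.apache.spark.sql.catalyst.util.CollationFactory

    // Reject case-sensitive, accent-insensitive (CS_AI) collations up front.
    def assertCollationSupported(collationId: Int): Unit = {
      if (CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId)) {
        throw new IllegalArgumentException(
          s"CS_AI collation (id=$collationId) is not supported by this expression")
      }
    }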
+ */ + public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) { + return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity == + Collation.CollationSpecICU.CaseSensitivity.CS && + Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity == + Collation.CollationSpecICU.AccentSensitivity.AI; + } + public static void assertValidProvider(String provider) throws SparkException { if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) { Map params = Map.of( @@ -923,4 +1062,12 @@ public static String getClosestSuggestionsOnInvalidName( return String.join(", ", suggestions); } + + public static List listCollations() { + return Collation.CollationSpec.listCollations(); + } + + public static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) { + return Collation.CollationSpec.loadCollationMeta(collationIdentifier); + } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 5719303a0dce8..a445cde52ad57 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -629,6 +629,8 @@ public void testStartsWith() throws SparkException { assertStartsWith("İonic", "Io", "UTF8_LCASE", false); assertStartsWith("İonic", "i\u0307o", "UTF8_LCASE", true); assertStartsWith("İonic", "İo", "UTF8_LCASE", true); + assertStartsWith("oİ", "oİ", "UTF8_LCASE", true); + assertStartsWith("oİ", "oi̇", "UTF8_LCASE", true); // Conditional case mapping (e.g. Greek sigmas). assertStartsWith("σ", "σ", "UTF8_BINARY", true); assertStartsWith("σ", "ς", "UTF8_BINARY", false); @@ -880,6 +882,8 @@ public void testEndsWith() throws SparkException { assertEndsWith("the İo", "Io", "UTF8_LCASE", false); assertEndsWith("the İo", "i\u0307o", "UTF8_LCASE", true); assertEndsWith("the İo", "İo", "UTF8_LCASE", true); + assertEndsWith("İo", "İo", "UTF8_LCASE", true); + assertEndsWith("İo", "i̇o", "UTF8_LCASE", true); // Conditional case mapping (e.g. Greek sigmas). assertEndsWith("σ", "σ", "UTF8_BINARY", true); assertEndsWith("σ", "ς", "UTF8_BINARY", false); diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index e2725a98a63bd..e83202d9e5ee3 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1,4 +1,10 @@ { + "ADD_DEFAULT_UNSUPPORTED" : { + "message" : [ + "Failed to execute command because DEFAULT values are not supported when adding new columns to previously existing target data source with table provider: \"\"." + ], + "sqlState" : "42623" + }, "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION" : { "message" : [ "Non-deterministic expression should not appear in the arguments of an aggregate function." @@ -434,6 +440,12 @@ ], "sqlState" : "42846" }, + "CANNOT_USE_KRYO" : { + "message" : [ + "Cannot load Kryo serialization codec. Kryo serialization cannot be used in the Spark Connect client. Use Java serialization, provide a custom Codec, or use Spark Classic instead." + ], + "sqlState" : "22KD3" + }, "CANNOT_WRITE_STATE_STORE" : { "message" : [ "Error writing state store files for provider ." @@ -449,13 +461,13 @@ }, "CAST_INVALID_INPUT" : { "message" : [ - "The value of the type cannot be cast to because it is malformed. 
Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary set to \"false\" to bypass this error." + "The value of the type cannot be cast to because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead." ], "sqlState" : "22018" }, "CAST_OVERFLOW" : { "message" : [ - "The value of the type cannot be cast to due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set to \"false\" to bypass this error." + "The value of the type cannot be cast to due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead." ], "sqlState" : "22003" }, @@ -898,7 +910,7 @@ }, "NON_STRING_TYPE" : { "message" : [ - "all arguments must be strings." + "all arguments of the function must be strings." ] }, "NULL_TYPE" : { @@ -1039,6 +1051,12 @@ ], "sqlState" : "42710" }, + "DATA_SOURCE_EXTERNAL_ERROR" : { + "message" : [ + "Encountered error when saving to external data source." + ], + "sqlState" : "KD010" + }, "DATA_SOURCE_NOT_EXIST" : { "message" : [ "Data source '' not found. Please make sure the data source is registered." @@ -1084,6 +1102,12 @@ ], "sqlState" : "42608" }, + "DEFAULT_UNSUPPORTED" : { + "message" : [ + "Failed to execute command because DEFAULT values are not supported for target data source with table provider: \"\"." + ], + "sqlState" : "42623" + }, "DISTINCT_WINDOW_FUNCTION_UNSUPPORTED" : { "message" : [ "Distinct window functions are not supported: ." @@ -1432,6 +1456,12 @@ ], "sqlState" : "2203G" }, + "FAILED_TO_LOAD_ROUTINE" : { + "message" : [ + "Failed to load routine ." + ], + "sqlState" : "38000" + }, "FAILED_TO_PARSE_TOO_COMPLEX" : { "message" : [ "The statement, including potential SQL functions and referenced views, was too complex to parse.", @@ -1457,6 +1487,12 @@ ], "sqlState" : "42704" }, + "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR" : { + "message" : [ + "An error occurred in the user provided function in flatMapGroupsWithState. Reason: " + ], + "sqlState" : "39000" + }, "FORBIDDEN_OPERATION" : { "message" : [ "The operation is not allowed on the : ." @@ -1469,6 +1505,12 @@ ], "sqlState" : "39000" }, + "FOREACH_USER_FUNCTION_ERROR" : { + "message" : [ + "An error occurred in the user provided function in foreach sink. Reason: " + ], + "sqlState" : "39000" + }, "FOUND_MULTIPLE_DATA_SOURCES" : { "message" : [ "Detected multiple data sources with the name ''. Please check the data source isn't simultaneously registered and located in the classpath." @@ -1565,6 +1607,36 @@ ], "sqlState" : "42601" }, + "IDENTITY_COLUMNS_DUPLICATED_SEQUENCE_GENERATOR_OPTION" : { + "message" : [ + "Duplicated IDENTITY column sequence generator option: ." + ], + "sqlState" : "42601" + }, + "IDENTITY_COLUMNS_ILLEGAL_STEP" : { + "message" : [ + "IDENTITY column step cannot be 0." + ], + "sqlState" : "42611" + }, + "IDENTITY_COLUMNS_UNSUPPORTED_DATA_TYPE" : { + "message" : [ + "DataType is not supported for IDENTITY columns." + ], + "sqlState" : "428H2" + }, + "IDENTITY_COLUMN_WITH_DEFAULT_VALUE" : { + "message" : [ + "A column cannot have both a default value and an identity column specification but column has default value: () and identity column specification: ()." + ], + "sqlState" : "42623" + }, + "ILLEGAL_DAY_OF_WEEK" : { + "message" : [ + "Illegal input for day of week: ." 
+ ], + "sqlState" : "22009" + }, "ILLEGAL_STATE_STORE_VALUE" : { "message" : [ "Illegal value provided to the State Store" @@ -1942,6 +2014,12 @@ }, "sqlState" : "42903" }, + "INVALID_AGNOSTIC_ENCODER" : { + "message" : [ + "Found an invalid agnostic encoder. Expects an instance of AgnosticEncoder but got . For more information consult '/api/java/index.html?org/apache/spark/sql/Encoder.html'." + ], + "sqlState" : "42001" + }, "INVALID_ARRAY_INDEX" : { "message" : [ "The index is out of bounds. The array has elements. Use the SQL function `get()` to tolerate accessing element at invalid index and return NULL instead. If necessary set to \"false\" to bypass this error." @@ -2074,6 +2152,11 @@ "message" : [ "Too many letters in datetime pattern: . Please reduce pattern length." ] + }, + "SECONDS_FRACTION" : { + "message" : [ + "Cannot detect a seconds fraction pattern of variable length. Please make sure the pattern contains 'S', and does not contain illegal characters." + ] } }, "sqlState" : "22007" @@ -2372,6 +2455,11 @@ "Uncaught arithmetic exception while parsing ''." ] }, + "DAY_TIME_PARSING" : { + "message" : [ + "Error parsing interval day-time string: ." + ] + }, "INPUT_IS_EMPTY" : { "message" : [ "Interval string cannot be empty." @@ -2382,6 +2470,11 @@ "Interval string cannot be null." ] }, + "INTERVAL_PARSING" : { + "message" : [ + "Error parsing interval string." + ] + }, "INVALID_FRACTION" : { "message" : [ " cannot have fractional part." @@ -2417,15 +2510,35 @@ "Expect a unit name after but hit EOL." ] }, + "SECOND_NANO_FORMAT" : { + "message" : [ + "Interval string does not match second-nano format of ss.nnnnnnnnn." + ] + }, "UNKNOWN_PARSING_ERROR" : { "message" : [ "Unknown error when parsing ." ] }, + "UNMATCHED_FORMAT_STRING" : { + "message" : [ + "Interval string does not match format of when cast to : ." + ] + }, + "UNMATCHED_FORMAT_STRING_WITH_NOTICE" : { + "message" : [ + "Interval string does not match format of when cast to : . Set \"spark.sql.legacy.fromDayTimeString.enabled\" to \"true\" to restore the behavior before Spark 3.0." + ] + }, "UNRECOGNIZED_NUMBER" : { "message" : [ "Unrecognized number ." ] + }, + "UNSUPPORTED_FROM_TO_EXPRESSION" : { + "message" : [ + "Cannot support (interval '' to ) expression." + ] } }, "sqlState" : "22006" @@ -2489,6 +2602,24 @@ ], "sqlState" : "F0000" }, + "INVALID_LABEL_USAGE" : { + "message" : [ + "The usage of the label is invalid." + ], + "subClass" : { + "DOES_NOT_EXIST" : { + "message" : [ + "Label was used in the statement, but the label does not belong to any surrounding block." + ] + }, + "ITERATE_IN_COMPOUND" : { + "message" : [ + "ITERATE statement cannot be used with a label that belongs to a compound (BEGIN...END) body." + ] + } + }, + "sqlState" : "42K0L" + }, "INVALID_LAMBDA_FUNCTION_CALL" : { "message" : [ "Invalid lambda function call." @@ -3041,12 +3172,12 @@ "subClass" : { "NOT_ALLOWED_IN_SCOPE" : { "message" : [ - "Variable was declared on line , which is not allowed in this scope." + "Declaration of the variable is not allowed in this scope." ] }, "ONLY_AT_BEGINNING" : { "message" : [ - "Variable can only be declared at the beginning of the compound, but it was declared on line ." + "Variable can only be declared at the beginning of the compound." 
] } }, @@ -3671,6 +3802,12 @@ ], "sqlState" : "42K03" }, + "PIPE_OPERATOR_SELECT_CONTAINS_AGGREGATE_FUNCTION" : { + "message" : [ + "Aggregate function is not allowed when using the pipe operator |> SELECT clause; please use the pipe operator |> AGGREGATE clause instead" + ], + "sqlState" : "0A000" + }, "PIVOT_VALUE_DATA_TYPE_MISMATCH" : { "message" : [ "Invalid pivot value '': value data type does not match pivot column data type ." @@ -4326,6 +4463,24 @@ ], "sqlState" : "428EK" }, + "TRANSPOSE_EXCEED_ROW_LIMIT" : { + "message" : [ + "Number of rows exceeds the allowed limit of for TRANSPOSE. If this was intended, set to at least the current row count." + ], + "sqlState" : "54006" + }, + "TRANSPOSE_INVALID_INDEX_COLUMN" : { + "message" : [ + "Invalid index column for TRANSPOSE because: " + ], + "sqlState" : "42804" + }, + "TRANSPOSE_NO_LEAST_COMMON_TYPE" : { + "message" : [ + "Transpose requires non-index columns to share a least common type, but and do not." + ], + "sqlState" : "42K09" + }, "UDTF_ALIAS_NUMBER_MISMATCH" : { "message" : [ "The number of aliases supplied in the AS clause does not match the number of columns output by the UDTF.", @@ -5199,6 +5354,11 @@ "" ] }, + "SCALAR_SUBQUERY_IN_VALUES" : { + "message" : [ + "Scalar subqueries in the VALUES clause." + ] + }, "UNSUPPORTED_CORRELATED_EXPRESSION_IN_JOIN_CONDITION" : { "message" : [ "Correlated subqueries in the join predicate cannot reference both join inputs:", @@ -6168,7 +6328,7 @@ "Detected implicit cartesian product for join between logical plans", "", "and", - "rightPlan", + "", "Join condition is missing or trivial.", "Either: use the CROSS JOIN syntax to allow cartesian products between these relations, or: enable implicit cartesian products by setting the configuration variable spark.sql.crossJoin.enabled=true." ] @@ -6531,21 +6691,6 @@ "Sinks cannot request distribution and ordering in continuous execution mode." ] }, - "_LEGACY_ERROR_TEMP_1344" : { - "message" : [ - "Invalid DEFAULT value for column : fails to parse as a valid literal value." - ] - }, - "_LEGACY_ERROR_TEMP_1345" : { - "message" : [ - "Failed to execute command because DEFAULT values are not supported for target data source with table provider: \"\"." - ] - }, - "_LEGACY_ERROR_TEMP_1346" : { - "message" : [ - "Failed to execute command because DEFAULT values are not supported when adding new columns to previously existing target data source with table provider: \"\"." - ] - }, "_LEGACY_ERROR_TEMP_2000" : { "message" : [ ". If necessary set to false to bypass this error." @@ -6561,11 +6706,6 @@ "Type does not support ordered operations." ] }, - "_LEGACY_ERROR_TEMP_2011" : { - "message" : [ - "Unexpected data type ." - ] - }, "_LEGACY_ERROR_TEMP_2013" : { "message" : [ "Negative values found in " @@ -7773,7 +7913,7 @@ }, "_LEGACY_ERROR_TEMP_3055" : { "message" : [ - "ScalarFunction '' neither implement magic method nor override 'produceResult'" + "ScalarFunction neither implement magic method nor override 'produceResult'" ] }, "_LEGACY_ERROR_TEMP_3056" : { @@ -8379,36 +8519,6 @@ "The number of fields () in the partition identifier is not equal to the partition schema length (). The identifier might not refer to one partition." 
] }, - "_LEGACY_ERROR_TEMP_3209" : { - "message" : [ - "Illegal input for day of week: " - ] - }, - "_LEGACY_ERROR_TEMP_3210" : { - "message" : [ - "Interval string does not match second-nano format of ss.nnnnnnnnn" - ] - }, - "_LEGACY_ERROR_TEMP_3211" : { - "message" : [ - "Error parsing interval day-time string: " - ] - }, - "_LEGACY_ERROR_TEMP_3212" : { - "message" : [ - "Cannot support (interval '' to ) expression" - ] - }, - "_LEGACY_ERROR_TEMP_3213" : { - "message" : [ - "Error parsing interval string: " - ] - }, - "_LEGACY_ERROR_TEMP_3214" : { - "message" : [ - "Interval string does not match format of when cast to : " - ] - }, "_LEGACY_ERROR_TEMP_3215" : { "message" : [ "Expected a Boolean type expression in replaceNullWithFalse, but got the type in ." diff --git a/common/utils/src/main/resources/error/error-states.json b/common/utils/src/main/resources/error/error-states.json index c369db3f65058..87811fef9836e 100644 --- a/common/utils/src/main/resources/error/error-states.json +++ b/common/utils/src/main/resources/error/error-states.json @@ -7417,6 +7417,12 @@ "standard": "N", "usedBy": ["Databricks"] }, + "KD010": { + "description": "external data source failure", + "origin": "Databricks", + "standard": "N", + "usedBy": ["Databricks"] + }, "P0000": { "description": "procedural logic error", "origin": "PostgreSQL", diff --git a/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala b/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala index a1934dcf7a007..e2dd0da1aac85 100644 --- a/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala +++ b/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala @@ -19,7 +19,6 @@ package org.apache.spark import java.net.URL -import scala.collection.immutable.Map import scala.jdk.CollectionConverters._ import com.fasterxml.jackson.annotation.JsonIgnore @@ -52,7 +51,7 @@ class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) { val sub = new StringSubstitutor(sanitizedParameters.asJava) sub.setEnableUndefinedVariableException(true) sub.setDisableSubstitutionInValues(true) - try { + val errorMessage = try { sub.replace(ErrorClassesJsonReader.TEMPLATE_REGEX.replaceAllIn( messageTemplate, "\\$\\{$1\\}")) } catch { @@ -61,6 +60,17 @@ class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) { s"MessageTemplate: $messageTemplate, " + s"Parameters: $messageParameters", i) } + if (util.SparkEnvUtils.isTesting) { + val placeHoldersNum = ErrorClassesJsonReader.TEMPLATE_REGEX.findAllIn(messageTemplate).length + if (placeHoldersNum < sanitizedParameters.size) { + throw SparkException.internalError( + s"Found unused message parameters of the error class '$errorClass'. " + + s"Its error message format has $placeHoldersNum placeholders, " + + s"but the passed message parameters map has ${sanitizedParameters.size} items. 
" + + "Consider to add placeholders to the error format or remove unused message parameters.") + } + } + errorMessage } def getMessageParameters(errorClass: String): Seq[String] = { diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index a7e4f186000b5..12d456a371d07 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -266,6 +266,7 @@ private[spark] object LogKeys { case object FEATURE_NAME extends LogKey case object FETCH_SIZE extends LogKey case object FIELD_NAME extends LogKey + case object FIELD_TYPE extends LogKey case object FILES extends LogKey case object FILE_ABSOLUTE_PATH extends LogKey case object FILE_END_OFFSET extends LogKey @@ -652,6 +653,7 @@ private[spark] object LogKeys { case object RECEIVER_IDS extends LogKey case object RECORDS extends LogKey case object RECOVERY_STATE extends LogKey + case object RECURSIVE_DEPTH extends LogKey case object REDACTED_STATEMENT extends LogKey case object REDUCE_ID extends LogKey case object REGEX extends LogKey diff --git a/common/utils/src/main/scala/org/apache/spark/scheduler/SparkListenerEvent.scala b/common/utils/src/main/scala/org/apache/spark/scheduler/SparkListenerEvent.scala new file mode 100644 index 0000000000000..1f1d3492d6ac5 --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/scheduler/SparkListenerEvent.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler + +import com.fasterxml.jackson.annotation.JsonTypeInfo + +import org.apache.spark.annotation.DeveloperApi + +@DeveloperApi +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") +trait SparkListenerEvent { + /* Whether output this event to the event log */ + protected[spark] def logEvent: Boolean = true +} diff --git a/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala index 42a1d1612aeeb..d54a2f2ed9cea 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala @@ -342,7 +342,7 @@ private[spark] object MavenUtils extends Logging { } /* Set ivy settings for location of cache, if option is supplied */ - private def processIvyPathArg(ivySettings: IvySettings, ivyPath: Option[String]): Unit = { + private[util] def processIvyPathArg(ivySettings: IvySettings, ivyPath: Option[String]): Unit = { val alternateIvyDir = ivyPath.filterNot(_.trim.isEmpty).getOrElse { // To protect old Ivy-based systems like old Spark from Apache Ivy 2.5.2's incompatibility. System.getProperty("ivy.home", diff --git a/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala b/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala index 76062074edcaf..140de836622f4 100644 --- a/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala +++ b/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala @@ -365,7 +365,7 @@ private[spark] object IvyTestUtils { useIvyLayout: Boolean = false, withPython: Boolean = false, withR: Boolean = false, - ivySettings: IvySettings = new IvySettings)(f: String => Unit): Unit = { + ivySettings: IvySettings = defaultIvySettings())(f: String => Unit): Unit = { val deps = dependencies.map(MavenUtils.extractMavenCoordinates) purgeLocalIvyCache(artifact, deps, ivySettings) val repo = createLocalRepositoryForTests(artifact, dependencies, rootDir, useIvyLayout, @@ -401,4 +401,16 @@ private[spark] object IvyTestUtils { } } } + + /** + * Creates and initializes a new instance of IvySettings with default configurations. + * The method processes the Ivy path argument using MavenUtils to ensure proper setup. + * + * @return A newly created and configured instance of IvySettings. + */ + private def defaultIvySettings(): IvySettings = { + val settings = new IvySettings + MavenUtils.processIvyPathArg(ivySettings = settings, ivyPath = None) + settings + } } diff --git a/common/variant/README.md b/common/variant/README.md index a66d708da75bf..4ed7c16f5b6ed 100644 --- a/common/variant/README.md +++ b/common/variant/README.md @@ -333,27 +333,27 @@ The Decimal type contains a scale, but no precision. 
The implied precision of a | Object | `2` | A collection of (string-key, variant-value) pairs | | Array | `3` | An ordered sequence of variant values | -| Primitive Type | Type ID | Equivalent Parquet Type | Binary format | -|-----------------------------|---------|-----------------------------|---------------------------------------------------------------------------------------------------------------------| -| null | `0` | any | none | -| boolean (True) | `1` | BOOLEAN | none | -| boolean (False) | `2` | BOOLEAN | none | -| int8 | `3` | INT(8, signed) | 1 byte | -| int16 | `4` | INT(16, signed) | 2 byte little-endian | -| int32 | `5` | INT(32, signed) | 4 byte little-endian | -| int64 | `6` | INT(64, signed) | 8 byte little-endian | -| double | `7` | DOUBLE | IEEE little-endian | -| decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| decimal8 | `9` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| decimal16 | `10` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| date | `11` | DATE | 4 byte little-endian | -| timestamp | `12` | TIMESTAMP(true, MICROS) | 8-byte little-endian | -| timestamp without time zone | `13` | TIMESTAMP(false, MICROS) | 8-byte little-endian | -| float | `14` | FLOAT | IEEE little-endian | -| binary | `15` | BINARY | 4 byte little-endian size, followed by bytes | -| string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes | -| year-month interval | `19` | INT(32, signed)1 | 1 byte denoting start field (1 bit) and end field (1 bit) starting at LSB followed by 4-byte little-endian value. | -| day-time interval | `20` | INT(64, signed)1 | 1 byte denoting start field (2 bits) and end field (2 bits) starting at LSB followed by 8-byte little-endian value. 
| +| Logical Type | Physical Type | Type ID | Equivalent Parquet Type | Binary format | +|----------------------|-----------------------------|---------|-----------------------------|---------------------------------------------------------------------------------------------------------------------| +| NullType | null | `0` | any | none | +| Boolean | boolean (True) | `1` | BOOLEAN | none | +| Boolean | boolean (False) | `2` | BOOLEAN | none | +| Exact Numeric | int8 | `3` | INT(8, signed) | 1 byte | +| Exact Numeric | int16 | `4` | INT(16, signed) | 2 byte little-endian | +| Exact Numeric | int32 | `5` | INT(32, signed) | 4 byte little-endian | +| Exact Numeric | int64 | `6` | INT(64, signed) | 8 byte little-endian | +| Double | double | `7` | DOUBLE | IEEE little-endian | +| Exact Numeric | decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| Exact Numeric | decimal8 | `9` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| Exact Numeric | decimal16 | `10` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| Date | date | `11` | DATE | 4 byte little-endian | +| Timestamp | timestamp | `12` | TIMESTAMP(true, MICROS) | 8-byte little-endian | +| TimestampNTZ | timestamp without time zone | `13` | TIMESTAMP(false, MICROS) | 8-byte little-endian | +| Float | float | `14` | FLOAT | IEEE little-endian | +| Binary | binary | `15` | BINARY | 4 byte little-endian size, followed by bytes | +| String | string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes | +| YMInterval | year-month interval | `19` | INT(32, signed)1 | 1 byte denoting start field (1 bit) and end field (1 bit) starting at LSB followed by 4-byte little-endian value. | +| DTInterval | day-time interval | `20` | INT(64, signed)1 | 1 byte denoting start field (2 bits) and end field (2 bits) starting at LSB followed by 8-byte little-endian value. | | Decimal Precision | Decimal value type | |-----------------------|--------------------| @@ -362,6 +362,8 @@ The Decimal type contains a scale, but no precision. The implied precision of a | 18 <= precision <= 38 | int128 | | > 38 | Not supported | +The *Logical Type* column indicates logical equivalence of physically encoded types. For example, a user expression operating on a string value containing "hello" should behave the same, whether it is encoded with the short string optimization, or long string encoding. Similarly, user expressions operating on an *int8* value of 1 should behave the same as a decimal16 with scale 2 and unscaled value 100. + The year-month and day-time interval types have one byte at the beginning indicating the start and end fields. In the case of the year-month interval, the least significant bit denotes the start field and the next least significant bit denotes the end field. The remaining 6 bits are unused. A field value of 0 represents YEAR and 1 represents MONTH. In the case of the day-time interval, the least significant 2 bits denote the start field and the next least significant 2 bits denote the end field. The remaining 4 bits are unused. A field value of 0 represents DAY, 1 represents HOUR, 2 represents MINUTE, and 3 represents SECOND. Type IDs 17 and 18 were originally reserved for a prototype feature (string-from-metadata) that was never implemented. 
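A short worked example of the encodings in the table above, restricted to the value body bytes (the surrounding variant header and metadata bytes are omitted here); it also illustrates the logical-equivalence note above: an int8 value of 1 and a decimal with scale 2 and unscaled value 100 denote the same number.

    // Body bytes per the table: int8 is a single byte; decimal4 is a 1-byte scale
    // followed by the little-endian int32 unscaled value.
    val int8One     = Array[Byte](0x01)                            // 1
    val decimal4One = Array[Byte](0x02, 0x64, 0x00, 0x00, 0x00)    // scale = 2, unscaled = 100 -> 1.00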
These IDs are available for use by new types. diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 7d80998d96eb1..0b85b208242cb 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -42,7 +42,8 @@ private[sql] case class AvroDataToCatalyst( val dt = SchemaConverters.toSqlType( expectedSchema, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType).dataType + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth).dataType parseMode match { // With PermissiveMode, the output Catalyst row might contain columns of null values for // corrupt records, even if some of the columns are not nullable in the user-provided schema. @@ -69,7 +70,8 @@ private[sql] case class AvroDataToCatalyst( dataType, avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType) + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth) @transient private var decoder: BinaryDecoder = _ diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 877c3f89e88c0..ac20614553ca2 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -51,14 +51,16 @@ private[sql] class AvroDeserializer( datetimeRebaseSpec: RebaseSpec, filters: StructFilters, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String) { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int) { def this( rootAvroType: Schema, rootCatalystType: DataType, datetimeRebaseMode: String, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String) = { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int) = { this( rootAvroType, rootCatalystType, @@ -66,7 +68,8 @@ private[sql] class AvroDeserializer( RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)), new NoopFilters, useStableIdForUnionType, - stableIdPrefixForUnionType) + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) } private lazy val decimalConversions = new DecimalConversion() @@ -128,7 +131,8 @@ private[sql] class AvroDeserializer( s"schema is incompatible (avroType = $avroType, sqlType = ${catalystType.sql})" val realDataType = SchemaConverters.toSqlType( - avroType, useStableIdForUnionType, stableIdPrefixForUnionType).dataType + avroType, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth).dataType (avroType.getType, catalystType) match { case (NULL, NullType) => (updater, ordinal, _) => diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 372f24b54f5c4..264c3a1f48abe 100755 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -145,7 +145,8 @@ private[sql] class AvroFileFormat extends FileFormat datetimeRebaseMode, avroFilters, parsedOptions.useStableIdForUnionType, - parsedOptions.stableIdPrefixForUnionType) + parsedOptions.stableIdPrefixForUnionType, + 
parsedOptions.recursiveFieldMaxDepth) override val stopPosition = file.start + file.length override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index 4332904339f19..e0c6ad3ee69d3 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf /** @@ -136,6 +137,15 @@ private[sql] class AvroOptions( val stableIdPrefixForUnionType: String = parameters .getOrElse(STABLE_ID_PREFIX_FOR_UNION_TYPE, "member_") + + val recursiveFieldMaxDepth: Int = + parameters.get(RECURSIVE_FIELD_MAX_DEPTH).map(_.toInt).getOrElse(-1) + + if (recursiveFieldMaxDepth > RECURSIVE_FIELD_MAX_DEPTH_LIMIT) { + throw QueryCompilationErrors.avroOptionsException( + RECURSIVE_FIELD_MAX_DEPTH, + s"Should not be greater than $RECURSIVE_FIELD_MAX_DEPTH_LIMIT.") + } } private[sql] object AvroOptions extends DataSourceOptions { @@ -170,4 +180,25 @@ private[sql] object AvroOptions extends DataSourceOptions { // When STABLE_ID_FOR_UNION_TYPE is enabled, the option allows to configure the prefix for fields // of Avro Union type. val STABLE_ID_PREFIX_FOR_UNION_TYPE = newOption("stableIdentifierPrefixForUnionType") + + /** + * Adds support for recursive fields. If this option is not specified or is set to 0, recursive + * fields are not permitted. Setting it to 1 drops all recursive fields, 2 allows recursive + * fields to be recursed once, and 3 allows it to be recursed twice and so on, up to 15. + * Values larger than 15 are not allowed in order to avoid inadvertently creating very large + * schemas. If an avro message has depth beyond this limit, the Spark struct returned is + * truncated after the recursion limit. + * + * Examples: Consider an Avro schema with a recursive field: + * {"type" : "record", "name" : "Node", "fields" : [{"name": "Id", "type": "int"}, + * {"name": "Next", "type": ["null", "Node"]}]} + * The following lists the parsed schema with different values for this setting. + * 1: `struct` + * 2: `struct>` + * 3: `struct>>` + * and so on. 
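A hedged usage sketch, not part of this patch, showing how a reader might bound recursive expansion with the option documented above; `spark` (a SparkSession) and the input path are assumptions:

    // With the Node schema from the doc comment above, a max depth of 2 expands the
    // recursive `Next` field once and drops the next level of recursion.
    val df = spark.read
      .format("avro")
      .option("recursiveFieldMaxDepth", "2")
      .load("/path/to/nodes.avro")   // hypothetical input path
    df.printSchema()                  // Next is expanded once; its nested Next is dropped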
+ */ + val RECURSIVE_FIELD_MAX_DEPTH = newOption("recursiveFieldMaxDepth") + + val RECURSIVE_FIELD_MAX_DEPTH_LIMIT: Int = 15 } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 7cbc30f1fb3dc..594ebb4716c41 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -65,7 +65,8 @@ private[sql] object AvroUtils extends Logging { SchemaConverters.toSqlType( avroSchema, parsedOptions.useStableIdForUnionType, - parsedOptions.stableIdPrefixForUnionType).dataType match { + parsedOptions.stableIdPrefixForUnionType, + parsedOptions.recursiveFieldMaxDepth).dataType match { case t: StructType => Some(t) case _ => throw new RuntimeException( s"""Avro schema cannot be converted to a Spark SQL StructType: diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index b2285aa966ddb..1168a887abd8e 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -27,6 +27,10 @@ import org.apache.avro.LogicalTypes.{Date, Decimal, LocalTimestampMicros, LocalT import org.apache.avro.Schema.Type._ import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging +import org.apache.spark.internal.LogKeys.{FIELD_NAME, FIELD_TYPE, RECURSIVE_DEPTH} +import org.apache.spark.internal.MDC +import org.apache.spark.sql.avro.AvroOptions.RECURSIVE_FIELD_MAX_DEPTH_LIMIT import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.types._ import org.apache.spark.sql.types.Decimal.minBytesForPrecision @@ -36,7 +40,7 @@ import org.apache.spark.sql.types.Decimal.minBytesForPrecision * versa. */ @DeveloperApi -object SchemaConverters { +object SchemaConverters extends Logging { private lazy val nullSchema = Schema.create(Schema.Type.NULL) /** @@ -48,14 +52,27 @@ object SchemaConverters { /** * Converts an Avro schema to a corresponding Spark SQL schema. - * + * + * @param avroSchema The Avro schema to convert. + * @param useStableIdForUnionType If true, Avro schema is deserialized into Spark SQL schema, + * and the Avro Union type is transformed into a structure where + * the field names remain consistent with their respective types. + * @param stableIdPrefixForUnionType The prefix to use to configure the prefix for fields of + * Avro Union type + * @param recursiveFieldMaxDepth The maximum depth to recursively process fields in Avro schema. + * -1 means not supported. * @since 4.0.0 */ def toSqlType( avroSchema: Schema, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String): SchemaType = { - toSqlTypeHelper(avroSchema, Set.empty, useStableIdForUnionType, stableIdPrefixForUnionType) + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int = -1): SchemaType = { + val schema = toSqlTypeHelper(avroSchema, Map.empty, useStableIdForUnionType, + stableIdPrefixForUnionType, recursiveFieldMaxDepth) + // the top level record should never return null + assert(schema != null) + schema } /** * Converts an Avro schema to a corresponding Spark SQL schema. 
@@ -63,17 +80,17 @@ object SchemaConverters { * @since 2.4.0 */ def toSqlType(avroSchema: Schema): SchemaType = { - toSqlType(avroSchema, false, "") + toSqlType(avroSchema, false, "", -1) } @deprecated("using toSqlType(..., useStableIdForUnionType: Boolean) instead", "4.0.0") def toSqlType(avroSchema: Schema, options: Map[String, String]): SchemaType = { val avroOptions = AvroOptions(options) - toSqlTypeHelper( + toSqlType( avroSchema, - Set.empty, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType) + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth) } // The property specifies Catalyst type of the given field @@ -81,9 +98,10 @@ object SchemaConverters { private def toSqlTypeHelper( avroSchema: Schema, - existingRecordNames: Set[String], + existingRecordNames: Map[String, Int], useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String): SchemaType = { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int): SchemaType = { avroSchema.getType match { case INT => avroSchema.getLogicalType match { case _: Date => SchemaType(DateType, nullable = false) @@ -128,62 +146,110 @@ object SchemaConverters { case NULL => SchemaType(NullType, nullable = true) case RECORD => - if (existingRecordNames.contains(avroSchema.getFullName)) { + val recursiveDepth: Int = existingRecordNames.getOrElse(avroSchema.getFullName, 0) + if (recursiveDepth > 0 && recursiveFieldMaxDepth <= 0) { throw new IncompatibleSchemaException(s""" - |Found recursive reference in Avro schema, which can not be processed by Spark: - |${avroSchema.toString(true)} + |Found recursive reference in Avro schema, which can not be processed by Spark by + | default: ${avroSchema.toString(true)}. Try setting the option `recursiveFieldMaxDepth` + | to 1 - $RECURSIVE_FIELD_MAX_DEPTH_LIMIT. """.stripMargin) - } - val newRecordNames = existingRecordNames + avroSchema.getFullName - val fields = avroSchema.getFields.asScala.map { f => - val schemaType = toSqlTypeHelper( - f.schema(), - newRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType) - StructField(f.name, schemaType.dataType, schemaType.nullable) - } + } else if (recursiveDepth > 0 && recursiveDepth >= recursiveFieldMaxDepth) { + logInfo( + log"The field ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} is dropped at recursive depth " + + log"${MDC(RECURSIVE_DEPTH, recursiveDepth)}." 
+ ) + null + } else { + val newRecordNames = + existingRecordNames + (avroSchema.getFullName -> (recursiveDepth + 1)) + val fields = avroSchema.getFields.asScala.map { f => + val schemaType = toSqlTypeHelper( + f.schema(), + newRecordNames, + useStableIdForUnionType, + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + null + } + else { + StructField(f.name, schemaType.dataType, schemaType.nullable) + } + }.filter(_ != null).toSeq - SchemaType(StructType(fields.toArray), nullable = false) + SchemaType(StructType(fields), nullable = false) + } case ARRAY => val schemaType = toSqlTypeHelper( avroSchema.getElementType, existingRecordNames, useStableIdForUnionType, - stableIdPrefixForUnionType) - SchemaType( - ArrayType(schemaType.dataType, containsNull = schemaType.nullable), - nullable = false) + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." + ) + null + } else { + SchemaType( + ArrayType(schemaType.dataType, containsNull = schemaType.nullable), + nullable = false) + } case MAP => val schemaType = toSqlTypeHelper(avroSchema.getValueType, - existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType) - SchemaType( - MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), - nullable = false) + existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." + ) + null + } else { + SchemaType( + MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), + nullable = false) + } case UNION => if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { // In case of a union with null, eliminate it and make a recursive call val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema) - if (remainingUnionTypes.size == 1) { - toSqlTypeHelper( - remainingUnionTypes.head, - existingRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType).copy(nullable = true) + val remainingSchema = + if (remainingUnionTypes.size == 1) { + remainingUnionTypes.head + } else { + Schema.createUnion(remainingUnionTypes.asJava) + } + val schemaType = toSqlTypeHelper( + remainingSchema, + existingRecordNames, + useStableIdForUnionType, + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." 
+ ) + null } else { - toSqlTypeHelper( - Schema.createUnion(remainingUnionTypes.asJava), - existingRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType).copy(nullable = true) + schemaType.copy(nullable = true) } } else avroSchema.getTypes.asScala.map(_.getType).toSeq match { case Seq(t1) => toSqlTypeHelper(avroSchema.getTypes.get(0), - existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType) + existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth) case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => SchemaType(LongType, nullable = false) case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => @@ -201,29 +267,33 @@ object SchemaConverters { s, existingRecordNames, useStableIdForUnionType, - stableIdPrefixForUnionType) - - val fieldName = if (useStableIdForUnionType) { - // Avro's field name may be case sensitive, so field names for two named type - // could be "a" and "A" and we need to distinguish them. In this case, we throw - // an exception. - // Stable id prefix can be empty so the name of the field can be just the type. - val tempFieldName = s"${stableIdPrefixForUnionType}${s.getName}" - if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) { - throw new IncompatibleSchemaException( - "Cannot generate stable identifier for Avro union type due to name " + - s"conflict of type name ${s.getName}") - } - tempFieldName + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + null } else { - s"member$i" - } + val fieldName = if (useStableIdForUnionType) { + // Avro's field name may be case sensitive, so field names for two named types + // could be "a" and "A" and we need to distinguish them. In this case, we throw + // an exception. + // Stable id prefix can be empty so the name of the field can be just the type.
+ val tempFieldName = s"${stableIdPrefixForUnionType}${s.getName}" + if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) { + throw new IncompatibleSchemaException( + "Cannot generate stable identifier for Avro union type due to name " + + s"conflict of type name ${s.getName}") + } + tempFieldName + } else { + s"member$i" + } - // All fields are nullable because only one of them is set at a time - StructField(fieldName, schemaType.dataType, nullable = true) - } + // All fields are nullable because only one of them is set at a time + StructField(fieldName, schemaType.dataType, nullable = true) + } + }.filter(_ != null).toSeq - SchemaType(StructType(fields.toArray), nullable = false) + SchemaType(StructType(fields), nullable = false) } case other => throw new IncompatibleSchemaException(s"Unsupported type $other") diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala index 1083c99160724..a13faf3b51560 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala @@ -105,7 +105,8 @@ case class AvroPartitionReaderFactory( datetimeRebaseMode, avroFilters, options.useStableIdForUnionType, - options.stableIdPrefixForUnionType) + options.stableIdPrefixForUnionType, + options.recursiveFieldMaxDepth) override val stopPosition = partitionedFile.start + partitionedFile.length override def next(): Boolean = hasNextRow diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala index fe61fe3db8786..8ec711b2757f5 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala @@ -37,7 +37,7 @@ case class AvroTable( fallbackFileFormat: Class[_ <: FileFormat]) extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { override def newScanBuilder(options: CaseInsensitiveStringMap): AvroScanBuilder = - new AvroScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + AvroScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options)) override def inferSchema(files: Seq[FileStatus]): Option[StructType] = AvroUtils.inferSchema(sparkSession, options.asScala.toMap, files) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 388347537a4d6..311eda3a1b6ae 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -291,7 +291,8 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite RebaseSpec(LegacyBehaviorPolicy.CORRECTED), filters, false, - "") + "", + -1) val deserialized = deserializer.deserialize(data) expected match { case None => assert(deserialized == None) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala index 256b608feaa1f..0db9d284c4512 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala +++ 
b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala @@ -54,7 +54,7 @@ class AvroCodecSuite extends FileSourceCodecSuite { s"""CREATE TABLE avro_t |USING $format OPTIONS('compression'='unsupported') |AS SELECT 1 as id""".stripMargin)), - errorClass = "CODEC_SHORT_NAME_NOT_FOUND", + condition = "CODEC_SHORT_NAME_NOT_FOUND", sqlState = Some("42704"), parameters = Map("codecName" -> "unsupported") ) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala index 432c3fa9be3ac..a7f7abadcf485 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.functions.{col, lit, struct} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{BinaryType, StructType} +import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} class AvroFunctionsSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -329,7 +329,7 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { s""" |select to_avro(s, 42) as result from t |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map("sqlExpr" -> "\"to_avro(s, 42)\"", "msg" -> ("The second argument of the TO_AVRO SQL function must be a constant string " + "containing the JSON representation of the schema to use for converting the value to " + @@ -344,7 +344,7 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { s""" |select from_avro(s, 42, '') as result from t |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map("sqlExpr" -> "\"from_avro(s, 42, )\"", "msg" -> ("The second argument of the FROM_AVRO SQL function must be a constant string " + "containing the JSON representation of the schema to use for converting the value " + @@ -359,7 +359,7 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { s""" |select from_avro(s, '$jsonFormatSchema', 42) as result from t |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> s"\"from_avro(s, $jsonFormatSchema, 42)\"".stripMargin, @@ -374,6 +374,37 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { } } + + test("roundtrip in to_avro and from_avro - recursive schema") { + val catalystSchema = + StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType))))))))) + + val avroSchema = s""" + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [ + | {"name": "Id", "type": "int"}, + | {"name": "Name", "type": ["null", "test_schema"]} + | ] + |} + """.stripMargin + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(2, Row(3, Row(4))), Row(1, null))), + catalystSchema).select(struct("Id", "Name").as("struct")) + + val avroStructDF = df.select(functions.to_avro($"struct", 
avroSchema).as("avro")) + checkAnswer(avroStructDF.select( + functions.from_avro($"avro", avroSchema, Map( + "recursiveFieldMaxDepth" -> "3").asJava)), df) + } + private def serialize(record: GenericRecord, avroSchema: String): Array[Byte] = { val schema = new Schema.Parser().parse(avroSchema) val datumWriter = new GenericDatumWriter[GenericRecord](schema) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala index 429f3c0deca6a..751ac275e048a 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala @@ -439,7 +439,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession { assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[SparkArithmeticException], - errorClass = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", + condition = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", parameters = Map( "value" -> "0", "precision" -> "4", diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala index 9b3bb929a700d..c1ab96a63eb26 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala @@ -77,7 +77,8 @@ class AvroRowReaderSuite RebaseSpec(CORRECTED), new NoopFilters, false, - "") + "", + -1) override val stopPosition = fileSize override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala index cbcbc2e7e76a6..3643a95abe19c 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala @@ -228,7 +228,8 @@ object AvroSerdeSuite { RebaseSpec(CORRECTED), new NoopFilters, false, - "") + "", + -1) } /** diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index b20ee4b3cc231..be887bd5237b0 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -894,7 +894,7 @@ abstract class AvroSuite assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], - errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", + condition = "AVRO_INCOMPATIBLE_READ_TYPE", parameters = Map("avroPath" -> "field 'a'", "sqlPath" -> "field 'a'", "avroType" -> "decimal\\(12,10\\)", @@ -972,7 +972,7 @@ abstract class AvroSuite assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], - errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", + condition = "AVRO_INCOMPATIBLE_READ_TYPE", parameters = Map("avroPath" -> "field 'a'", "sqlPath" -> "field 'a'", "avroType" -> "interval day to second", @@ -1009,7 +1009,7 @@ abstract class AvroSuite assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], - errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", + 
condition = "AVRO_INCOMPATIBLE_READ_TYPE", parameters = Map("avroPath" -> "field 'a'", "sqlPath" -> "field 'a'", "avroType" -> "interval year to month", @@ -1673,7 +1673,7 @@ abstract class AvroSuite exception = intercept[AnalysisException] { sql("select interval 1 days").write.format("avro").mode("overwrite").save(tempDir) }, - errorClass = "_LEGACY_ERROR_TEMP_1136", + condition = "_LEGACY_ERROR_TEMP_1136", parameters = Map.empty ) checkError( @@ -1681,7 +1681,7 @@ abstract class AvroSuite spark.udf.register("testType", () => new IntervalData()) sql("select testType()").write.format("avro").mode("overwrite").save(tempDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`testType()`", "columnType" -> "UDT(\"INTERVAL\")", @@ -2220,7 +2220,8 @@ abstract class AvroSuite } } - private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = { + private def checkSchemaWithRecursiveLoop(avroSchema: String, recursiveFieldMaxDepth: Int): + Unit = { val message = intercept[IncompatibleSchemaException] { SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema), false, "") }.getMessage @@ -2229,7 +2230,79 @@ abstract class AvroSuite } test("Detect recursive loop") { - checkSchemaWithRecursiveLoop(""" + for (recursiveFieldMaxDepth <- Seq(-1, 0)) { + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, // each element has a long + | {"name": "next", "type": ["null", "LongList"]} // optional next element + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields": [ + | { + | "name": "value", + | "type": { + | "type": "record", + | "name": "foo", + | "fields": [ + | { + | "name": "parent", + | "type": "LongList" + | } + | ] + | } + | } + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, + | {"name": "array", "type": {"type": "array", "items": "LongList"}} + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, + | {"name": "map", "type": {"type": "map", "values": "LongList"}} + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + } + } + + private def checkSparkSchemaEquals( + avroSchema: String, expectedSchema: StructType, recursiveFieldMaxDepth: Int): Unit = { + val sparkSchema = + SchemaConverters.toSqlType( + new Schema.Parser().parse(avroSchema), false, "", recursiveFieldMaxDepth).dataType + + assert(sparkSchema === expectedSchema) + } + + test("Translate recursive schema - union") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2238,9 +2311,57 @@ abstract class AvroSuite | {"name": "next", "type": ["null", "LongList"]} // optional next element | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = nonRecursiveFields.add("next", expectedSchema) + } + } + + test("Translate recursive schema - union - 2 non-null fields") { + val avroSchema = """ + |{ + | "type": "record", + | "name": 
"TreeNode", + | "fields": [ + | { + | "name": "name", + | "type": "string" + | }, + | { + | "name": "value", + | "type": [ + | "long" + | ] + | }, + | { + | "name": "children", + | "type": [ + | "null", + | { + | "type": "array", + | "items": "TreeNode" + | } + | ], + | "default": null + | } + | ] + |} + """.stripMargin + val nonRecursiveFields = new StructType().add("name", StringType, nullable = false) + .add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = nonRecursiveFields.add("children", + new ArrayType(expectedSchema, false), nullable = true) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - record") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2260,9 +2381,18 @@ abstract class AvroSuite | } | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", StructType(Seq()), nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = new StructType().add("value", + new StructType().add("parent", expectedSchema, nullable = false), nullable = false) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - array") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2271,9 +2401,18 @@ abstract class AvroSuite | {"name": "array", "type": {"type": "array", "items": "LongList"}} | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = + nonRecursiveFields.add("array", new ArrayType(expectedSchema, false), nullable = false) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - map") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2282,7 +2421,70 @@ abstract class AvroSuite | {"name": "map", "type": {"type": "map", "values": "LongList"}} | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = + nonRecursiveFields.add("map", + new MapType(StringType, expectedSchema, false), nullable = false) + } + } + + test("recursive schema integration test") { + val catalystSchema = + StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", NullType))))))))) + + val avroSchema = s""" + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [ + | {"name": "Id", "type": "int"}, + | {"name": "Name", "type": ["null", "test_schema"]} + | ] + |} + """.stripMargin + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(2, Row(3, Row(4, null))), Row(1, null))), + catalystSchema) + + withTempPath { tempDir => + df.write.format("avro").save(tempDir.getPath) + + val exc = intercept[AnalysisException] { + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 16) + .load(tempDir.getPath) + } + assert(exc.getMessage.contains("Should not be greater than 15.")) + 
+ checkAnswer( + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 10) + .load(tempDir.getPath), + df) + + checkAnswer( + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 1) + .load(tempDir.getPath), + df.select("Id")) + } } test("log a warning of ignoreExtension deprecation") { @@ -2726,7 +2928,7 @@ abstract class AvroSuite |LOCATION '${dir}' |AS SELECT ID, IF(ID=1,1,0) FROM v""".stripMargin) }, - errorClass = "INVALID_COLUMN_NAME_AS_PATH", + condition = "INVALID_COLUMN_NAME_AS_PATH", parameters = Map( "datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`") ) @@ -2777,7 +2979,7 @@ abstract class AvroSuite } test("SPARK-40667: validate Avro Options") { - assert(AvroOptions.getAllOptions.size == 11) + assert(AvroOptions.getAllOptions.size == 12) // Please add validation on any new Avro options here assert(AvroOptions.isValidOption("ignoreExtension")) assert(AvroOptions.isValidOption("mode")) @@ -2790,6 +2992,7 @@ abstract class AvroSuite assert(AvroOptions.isValidOption("datetimeRebaseMode")) assert(AvroOptions.isValidOption("enableStableIdentifiersForUnionType")) assert(AvroOptions.isValidOption("stableIdentifierPrefixForUnionType")) + assert(AvroOptions.isValidOption("recursiveFieldMaxDepth")) } test("SPARK-46633: read file with empty blocks") { @@ -2831,7 +3034,7 @@ class AvroV1Suite extends AvroSuite { sql("SELECT ID, IF(ID=1,1,0) FROM v").write.mode(SaveMode.Overwrite) .format("avro").save(dir.getCanonicalPath) }, - errorClass = "INVALID_COLUMN_NAME_AS_PATH", + condition = "INVALID_COLUMN_NAME_AS_PATH", parameters = Map( "datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`") ) @@ -2844,7 +3047,7 @@ class AvroV1Suite extends AvroSuite { .write.mode(SaveMode.Overwrite) .format("avro").save(dir.getCanonicalPath) }, - errorClass = "INVALID_COLUMN_NAME_AS_PATH", + condition = "INVALID_COLUMN_NAME_AS_PATH", parameters = Map( "datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`") ) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index c06cbbc0cdb42..3777f82594aae 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -22,6 +22,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto.{NAReplace, Relation} import org.apache.spark.connect.proto.Expression.{Literal => GLiteral} import org.apache.spark.connect.proto.NAReplace.Replacement +import org.apache.spark.sql.connect.ConnectConversions._ /** * Functionality for working with missing data in `DataFrame`s. 
@@ -29,7 +30,7 @@ import org.apache.spark.connect.proto.NAReplace.Replacement * @since 3.4.0 */ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation) - extends api.DataFrameNaFunctions[Dataset] { + extends api.DataFrameNaFunctions { import sparkSession.RichColumn override protected def drop(minNonNulls: Option[Int]): Dataset[Row] = diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 1ad98dc91b216..60bacd4e18ede 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -23,11 +23,8 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable import org.apache.spark.connect.proto.Parse.ParseFormat -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, SparkCharVarcharUtils} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.DataTypeProtoConverter -import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.types.StructType /** @@ -37,144 +34,44 @@ import org.apache.spark.sql.types.StructType * @since 3.4.0 */ @Stable -class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging { - - /** - * Specifies the input data source format. - * - * @since 3.4.0 - */ - def format(source: String): DataFrameReader = { - this.source = source - this - } +class DataFrameReader private[sql] (sparkSession: SparkSession) extends api.DataFrameReader { + type DS[U] = Dataset[U] - /** - * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema - * automatically from data. By specifying the schema here, the underlying data source can skip - * the schema inference step, and thus speed up data loading. - * - * @since 3.4.0 - */ - def schema(schema: StructType): DataFrameReader = { - if (schema != null) { - val replaced = SparkCharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] - this.userSpecifiedSchema = Option(replaced) - } - this - } + /** @inheritdoc */ + override def format(source: String): this.type = super.format(source) - /** - * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) - * can infer the input schema automatically from data. By specifying the schema here, the - * underlying data source can skip the schema inference step, and thus speed up data loading. - * - * {{{ - * spark.read.schema("a INT, b STRING, c DOUBLE").csv("test.csv") - * }}} - * - * @since 3.4.0 - */ - def schema(schemaString: String): DataFrameReader = { - schema(StructType.fromDDL(schemaString)) - } + /** @inheritdoc */ + override def schema(schema: StructType): this.type = super.schema(schema) - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. 
- * - * @since 3.4.0 - */ - def option(key: String, value: String): DataFrameReader = { - this.extraOptions = this.extraOptions + (key -> value) - this - } + /** @inheritdoc */ + override def schema(schemaString: String): this.type = super.schema(schemaString) - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def option(key: String, value: Boolean): DataFrameReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def option(key: String, value: Long): DataFrameReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def option(key: String, value: Double): DataFrameReader = option(key, value.toString) - - /** - * (Scala-specific) Adds input options for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def options(options: scala.collection.Map[String, String]): DataFrameReader = { - this.extraOptions ++= options - this - } + /** @inheritdoc */ + override def option(key: String, value: String): this.type = super.option(key, value) - /** - * Adds input options for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def options(options: java.util.Map[String, String]): DataFrameReader = { - this.options(options.asScala) - this - } + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) - /** - * Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external - * key-value stores). - * - * @since 3.4.0 - */ - def load(): DataFrame = { - load(Seq.empty: _*) // force invocation of `load(...varargs...)` - } + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) - /** - * Loads input in as a `DataFrame`, for data sources that require a path (e.g. data backed by a - * local or distributed file system). - * - * @since 3.4.0 - */ - def load(path: String): DataFrame = { - // force invocation of `load(...varargs...)` - load(Seq(path): _*) - } + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + override def options(options: scala.collection.Map[String, String]): this.type = + super.options(options) - /** - * Loads input in as a `DataFrame`, for data sources that support multiple paths. Only works if - * the source is a HadoopFsRelationProvider. 
- * - * @since 3.4.0 - */ + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = super.options(options) + + /** @inheritdoc */ + override def load(): DataFrame = load(Nil: _*) + + /** @inheritdoc */ + def load(path: String): DataFrame = load(Seq(path): _*) + + /** @inheritdoc */ @scala.annotation.varargs def load(paths: String*): DataFrame = { sparkSession.newDataFrame { builder => @@ -190,93 +87,29 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging } } - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL url named - * table and connection properties. - * - * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC - * in - * Data Source Option in the version you use. - * - * @since 3.4.0 - */ - def jdbc(url: String, table: String, properties: Properties): DataFrame = { - // properties should override settings in extraOptions. - this.extraOptions ++= properties.asScala - // explicit url and dbtable should override all - this.extraOptions ++= Seq("url" -> url, "dbtable" -> table) - format("jdbc").load() - } + /** @inheritdoc */ + override def jdbc(url: String, table: String, properties: Properties): DataFrame = + super.jdbc(url, table, properties) - // scalastyle:off line.size.limit - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL url named - * table. Partitions of the table will be retrieved in parallel based on the parameters passed - * to this function. - * - * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash - * your external database systems. - * - * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC - * in - * Data Source Option in the version you use. - * - * @param table - * Name of the table in the external database. - * @param columnName - * Alias of `partitionColumn` option. Refer to `partitionColumn` in - * Data Source Option in the version you use. - * @param connectionProperties - * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least - * a "user" and "password" property should be included. "fetchsize" can be used to control the - * number of rows per fetch and "queryTimeout" can be used to wait for a Statement object to - * execute to the given number of seconds. - * @since 3.4.0 - */ - // scalastyle:on line.size.limit - def jdbc( + /** @inheritdoc */ + override def jdbc( url: String, table: String, columnName: String, lowerBound: Long, upperBound: Long, numPartitions: Int, - connectionProperties: Properties): DataFrame = { - // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions. - this.extraOptions ++= Map( - "partitionColumn" -> columnName, - "lowerBound" -> lowerBound.toString, - "upperBound" -> upperBound.toString, - "numPartitions" -> numPartitions.toString) - jdbc(url, table, connectionProperties) - } - - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL url named - * table using connection properties. The `predicates` parameter gives a list expressions - * suitable for inclusion in WHERE clauses; each one defines one partition of the `DataFrame`. - * - * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash - * your external database systems. 
- * - * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC - * in - * Data Source Option in the version you use. - * - * @param table - * Name of the table in the external database. - * @param predicates - * Condition in the where clause for each partition. - * @param connectionProperties - * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least - * a "user" and "password" property should be included. "fetchsize" can be used to control the - * number of rows per fetch. - * @since 3.4.0 - */ + connectionProperties: Properties): DataFrame = + super.jdbc( + url, + table, + columnName, + lowerBound, + upperBound, + numPartitions, + connectionProperties) + + /** @inheritdoc */ def jdbc( url: String, table: String, @@ -296,207 +129,56 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging } } - /** - * Loads a JSON file and returns the results as a `DataFrame`. - * - * See the documentation on the overloaded `json()` method with varargs for more details. - * - * @since 3.4.0 - */ - def json(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - json(Seq(path): _*) - } + /** @inheritdoc */ + override def json(path: String): DataFrame = super.json(path) - /** - * Loads JSON files and returns the results as a `DataFrame`. - * - * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `multiLine` option to true. - * - * This function goes through the input once to determine the input schema. If you know the - * schema in advance, use the version that specifies the schema to avoid the extra scan. - * - * You can find the JSON-specific options for reading JSON files in - * Data Source Option in the version you use. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def json(paths: String*): DataFrame = { - format("json").load(paths: _*) - } + override def json(paths: String*): DataFrame = super.json(paths: _*) - /** - * Loads a `Dataset[String]` storing JSON objects (JSON Lines - * text format or newline-delimited JSON) and returns the result as a `DataFrame`. - * - * Unless the schema is specified using `schema` function, this function goes through the input - * once to determine the input schema. - * - * @param jsonDataset - * input Dataset with one JSON object per record - * @since 3.4.0 - */ + /** @inheritdoc */ def json(jsonDataset: Dataset[String]): DataFrame = parse(jsonDataset, ParseFormat.PARSE_FORMAT_JSON) - /** - * Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other - * overloaded `csv()` method for more details. - * - * @since 3.4.0 - */ - def csv(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - csv(Seq(path): _*) - } + /** @inheritdoc */ + override def csv(path: String): DataFrame = super.csv(path) - /** - * Loads CSV files and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can find the CSV-specific options for reading CSV files in - * Data Source Option in the version you use. 
- * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def csv(paths: String*): DataFrame = format("csv").load(paths: _*) - - /** - * Loads an `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`. - * - * If the schema is not specified using `schema` function and `inferSchema` option is enabled, - * this function goes through the input once to determine the input schema. - * - * If the schema is not specified using `schema` function and `inferSchema` option is disabled, - * it determines the columns as string types and it reads only the first line to determine the - * names and the number of fields. - * - * If the enforceSchema is set to `false`, only the CSV header in the first line is checked to - * conform specified or inferred schema. - * - * @note - * if `header` option is set to `true` when calling this API, all lines same with the header - * will be removed if exists. - * @param csvDataset - * input Dataset with one CSV row per record - * @since 3.4.0 - */ + override def csv(paths: String*): DataFrame = super.csv(paths: _*) + + /** @inheritdoc */ def csv(csvDataset: Dataset[String]): DataFrame = parse(csvDataset, ParseFormat.PARSE_FORMAT_CSV) - /** - * Loads a XML file and returns the result as a `DataFrame`. See the documentation on the other - * overloaded `xml()` method for more details. - * - * @since 4.0.0 - */ - def xml(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - xml(Seq(path): _*) - } + /** @inheritdoc */ + override def xml(path: String): DataFrame = super.xml(path) - /** - * Loads XML files and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can find the XML-specific options for reading XML files in - * Data Source Option in the version you use. - * - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def xml(paths: String*): DataFrame = format("xml").load(paths: _*) - - /** - * Loads an `Dataset[String]` storing XML object and returns the result as a `DataFrame`. - * - * If the schema is not specified using `schema` function and `inferSchema` option is enabled, - * this function goes through the input once to determine the input schema. - * - * @param xmlDataset - * input Dataset with one XML object per record - * @since 4.0.0 - */ + override def xml(paths: String*): DataFrame = super.xml(paths: _*) + + /** @inheritdoc */ def xml(xmlDataset: Dataset[String]): DataFrame = parse(xmlDataset, ParseFormat.PARSE_FORMAT_UNSPECIFIED) - /** - * Loads a Parquet file, returning the result as a `DataFrame`. See the documentation on the - * other overloaded `parquet()` method for more details. - * - * @since 3.4.0 - */ - def parquet(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - parquet(Seq(path): _*) - } + /** @inheritdoc */ + override def parquet(path: String): DataFrame = super.parquet(path) - /** - * Loads a Parquet file, returning the result as a `DataFrame`. - * - * Parquet-specific option(s) for reading Parquet files can be found in Data - * Source Option in the version you use. 
- * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def parquet(paths: String*): DataFrame = { - format("parquet").load(paths: _*) - } + override def parquet(paths: String*): DataFrame = super.parquet(paths: _*) - /** - * Loads an ORC file and returns the result as a `DataFrame`. - * - * @param path - * input path - * @since 3.4.0 - */ - def orc(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - orc(Seq(path): _*) - } + /** @inheritdoc */ + override def orc(path: String): DataFrame = super.orc(path) - /** - * Loads ORC files and returns the result as a `DataFrame`. - * - * ORC-specific option(s) for reading ORC files can be found in Data - * Source Option in the version you use. - * - * @param paths - * input paths - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def orc(paths: String*): DataFrame = format("orc").load(paths: _*) - - /** - * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch - * reading and the returned DataFrame is the batch scan query plan of this table. If it's a - * view, the returned DataFrame is simply the query plan of the view, which can either be a - * batch or streaming query plan. - * - * @param tableName - * is either a qualified or unqualified name that designates a table or view. If a database is - * specified, it identifies the table/view from the database. Otherwise, it first attempts to - * find a temporary view with the given name and then match the table/view from the current - * database. Note that, the global temporary view database is also valid here. - * @since 3.4.0 - */ + override def orc(paths: String*): DataFrame = super.orc(paths: _*) + + /** @inheritdoc */ def table(tableName: String): DataFrame = { + assertNoSpecifiedSchema("table") sparkSession.newDataFrame { builder => builder.getReadBuilder.getNamedTableBuilder .setUnparsedIdentifier(tableName) @@ -504,80 +186,19 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging } } - /** - * Loads text files and returns a `DataFrame` whose schema starts with a string column named - * "value", and followed by partitioned columns if there are any. See the documentation on the - * other overloaded `text()` method for more details. - * - * @since 3.4.0 - */ - def text(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - text(Seq(path): _*) - } + /** @inheritdoc */ + override def text(path: String): DataFrame = super.text(path) - /** - * Loads text files and returns a `DataFrame` whose schema starts with a string column named - * "value", and followed by partitioned columns if there are any. The text files must be encoded - * as UTF-8. - * - * By default, each line in the text files is a new row in the resulting DataFrame. For example: - * {{{ - * // Scala: - * spark.read.text("/path/to/spark/README.md") - * - * // Java: - * spark.read().text("/path/to/spark/README.md") - * }}} - * - * You can find the text-specific options for reading text files in - * Data Source Option in the version you use. - * - * @param paths - * input paths - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def text(paths: String*): DataFrame = format("text").load(paths: _*) - - /** - * Loads text files and returns a [[Dataset]] of String. See the documentation on the other - * overloaded `textFile()` method for more details. 
- * @since 3.4.0 - */ - def textFile(path: String): Dataset[String] = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - textFile(Seq(path): _*) - } + override def text(paths: String*): DataFrame = super.text(paths: _*) - /** - * Loads text files and returns a [[Dataset]] of String. The underlying schema of the Dataset - * contains a single string column named "value". The text files must be encoded as UTF-8. - * - * If the directory structure of the text files contains partitioning information, those are - * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. - * - * By default, each line in the text files is a new row in the resulting DataFrame. For example: - * {{{ - * // Scala: - * spark.read.textFile("/path/to/spark/README.md") - * - * // Java: - * spark.read().textFile("/path/to/spark/README.md") - * }}} - * - * You can set the text-specific options as specified in `DataFrameReader.text`. - * - * @param paths - * input path - * @since 3.4.0 - */ + /** @inheritdoc */ + override def textFile(path: String): Dataset[String] = super.textFile(path) + + /** @inheritdoc */ @scala.annotation.varargs - def textFile(paths: String*): Dataset[String] = { - assertNoSpecifiedSchema("textFile") - text(paths: _*).select("value").as(StringEncoder) - } + override def textFile(paths: String*): Dataset[String] = super.textFile(paths: _*) private def assertSourceFormatSpecified(): Unit = { if (source == null) { @@ -597,24 +218,4 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging } } } - - /** - * A convenient function for schema validation in APIs. - */ - private def assertNoSpecifiedSchema(operation: String): Unit = { - if (userSpecifiedSchema.nonEmpty) { - throw DataTypeErrors.userSpecifiedSchemaUnsupportedError(operation) - } - } - - /////////////////////////////////////////////////////////////////////////////////////// - // Builder pattern config options - /////////////////////////////////////////////////////////////////////////////////////// - - private var source: String = _ - - private var userSpecifiedSchema: Option[StructType] = None - - private var extraOptions = CaseInsensitiveMap[String](Map.empty) - } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 9f5ada0d7ec35..bb7cfa75a9ab9 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -22,6 +22,7 @@ import java.{lang => jl, util => ju} import org.apache.spark.connect.proto.{Relation, StatSampleBy} import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.functions.lit /** @@ -30,7 +31,7 @@ import org.apache.spark.sql.functions.lit * @since 3.4.0 */ final class DataFrameStatFunctions private[sql] (protected val df: DataFrame) - extends api.DataFrameStatFunctions[Dataset] { + extends api.DataFrameStatFunctions { private def root: Relation = df.plan.getRoot private val sparkSession: SparkSession = df.sparkSession diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index ce21f18501a79..accfff9f2b073 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -32,12 +32,13 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.OrderUtils +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.SparkResult -import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter, UdfUtils} +import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter} import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.functions.{struct, to_json} -import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, UnresolvedAttribute, UnresolvedRegex} +import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, ToScalaUDF, UDFAdaptors, UnresolvedAttribute, UnresolvedRegex} import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types.{Metadata, StructType} import org.apache.spark.storage.StorageLevel @@ -134,8 +135,8 @@ class Dataset[T] private[sql] ( val sparkSession: SparkSession, @DeveloperApi val plan: proto.Plan, val encoder: Encoder[T]) - extends api.Dataset[T, Dataset] { - type RGD = RelationalGroupedDataset + extends api.Dataset[T] { + type DS[U] = Dataset[U] import sparkSession.RichColumn @@ -481,7 +482,7 @@ class Dataset[T] private[sql] ( val unpivot = builder.getUnpivotBuilder .setInput(plan.getRoot) .addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava) - .setValueColumnName(variableColumnName) + .setVariableColumnName(variableColumnName) .setValueColumnName(valueColumnName) valuesOption.foreach { values => unpivot.getValuesBuilder @@ -489,6 +490,14 @@ class Dataset[T] private[sql] ( } } + private def buildTranspose(indices: Seq[Column]): DataFrame = + sparkSession.newDataFrame { builder => + val transpose = builder.getTransposeBuilder.setInput(plan.getRoot) + indices.foreach { indexColumn => + transpose.addIndexColumns(indexColumn.expr) + } + } + /** @inheritdoc */ @scala.annotation.varargs def groupBy(cols: Column*): RelationalGroupedDataset = { @@ -515,27 +524,11 @@ class Dataset[T] private[sql] ( result(0) } - /** - * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given - * key `func`. - * - * @group typedrel - * @since 3.5.0 - */ + /** @inheritdoc */ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func) } - /** - * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given - * key `func`. 
- * - * @group typedrel - * @since 3.5.0 - */ - def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = - groupByKey(UdfUtils.mapFunctionToScalaFunc(func))(encoder) - /** @inheritdoc */ @scala.annotation.varargs def rollup(cols: Column*): RelationalGroupedDataset = { @@ -582,6 +575,14 @@ class Dataset[T] private[sql] ( buildUnpivot(ids, None, variableColumnName, valueColumnName) } + /** @inheritdoc */ + def transpose(indexColumn: Column): DataFrame = + buildTranspose(Seq(indexColumn)) + + /** @inheritdoc */ + def transpose(): DataFrame = + buildTranspose(Seq.empty) + /** @inheritdoc */ def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder => builder.getLimitBuilder @@ -865,17 +866,17 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ def filter(f: FilterFunction[T]): Dataset[T] = { - filter(UdfUtils.filterFuncToScalaFunc(f)) + filter(ToScalaUDF(f)) } /** @inheritdoc */ def map[U: Encoder](f: T => U): Dataset[U] = { - mapPartitions(UdfUtils.mapFuncToMapPartitionsAdaptor(f)) + mapPartitions(UDFAdaptors.mapToMapPartitions(f)) } /** @inheritdoc */ def map[U](f: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - map(UdfUtils.mapFunctionToScalaFunc(f))(encoder) + mapPartitions(UDFAdaptors.mapToMapPartitions(f))(encoder) } /** @inheritdoc */ @@ -892,25 +893,11 @@ class Dataset[T] private[sql] ( } } - /** @inheritdoc */ - def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - mapPartitions(UdfUtils.mapPartitionsFuncToScalaFunc(f))(encoder) - } - - /** @inheritdoc */ - override def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] = - mapPartitions(UdfUtils.flatMapFuncToMapPartitionsAdaptor(func)) - - /** @inheritdoc */ - override def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - flatMap(UdfUtils.flatMapFuncToScalaFunc(f))(encoder) - } - /** @inheritdoc */ @deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0") def explode[A <: Product: TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame = { val generator = SparkUserDefinedFunction( - UdfUtils.iterableOnceToSeq(f), + UDFAdaptors.iterableOnceToSeq(f), UnboundRowEncoder :: Nil, ScalaReflection.encoderFor[Seq[A]]) select(col("*"), functions.inline(generator(struct(input: _*)))) @@ -921,31 +908,16 @@ class Dataset[T] private[sql] ( def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)( f: A => IterableOnce[B]): DataFrame = { val generator = SparkUserDefinedFunction( - UdfUtils.iterableOnceToSeq(f), + UDFAdaptors.iterableOnceToSeq(f), Nil, ScalaReflection.encoderFor[Seq[B]]) select(col("*"), functions.explode(generator(col(inputColumn))).as((outputColumn))) } - /** @inheritdoc */ - def foreach(f: T => Unit): Unit = { - foreachPartition(UdfUtils.foreachFuncToForeachPartitionsAdaptor(f)) - } - - /** @inheritdoc */ - override def foreach(func: ForeachFunction[T]): Unit = - foreach(UdfUtils.foreachFuncToScalaFunc(func)) - /** @inheritdoc */ def foreachPartition(f: Iterator[T] => Unit): Unit = { // Delegate to mapPartition with empty result. 
- mapPartitions(UdfUtils.foreachPartitionFuncToMapPartitionsAdaptor(f))(RowEncoder(Seq.empty)) - .collect() - } - - /** @inheritdoc */ - override def foreachPartition(func: ForeachPartitionFunction[T]): Unit = { - foreachPartition(UdfUtils.foreachPartitionFuncToScalaFunc(func)) + mapPartitions(UDFAdaptors.foreachPartitionToMapPartitions(f))(NullEncoder).collect() } /** @inheritdoc */ @@ -1047,51 +1019,12 @@ class Dataset[T] private[sql] ( new DataFrameWriterImpl[T](this) } - /** - * Create a write configuration builder for v2 sources. - * - * This builder is used to configure and execute write operations. For example, to append to an - * existing table, run: - * - * {{{ - * df.writeTo("catalog.db.table").append() - * }}} - * - * This can also be used to create or replace existing tables: - * - * {{{ - * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace() - * }}} - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def writeTo(table: String): DataFrameWriterV2[T] = { - new DataFrameWriterV2[T](table, this) + new DataFrameWriterV2Impl[T](table, this) } - /** - * Merges a set of updates, insertions, and deletions based on a source table into a target - * table. - * - * Scala Examples: - * {{{ - * spark.table("source") - * .mergeInto("target", $"source.id" === $"target.id") - * .whenMatched($"salary" === 100) - * .delete() - * .whenNotMatched() - * .insertAll() - * .whenNotMatchedBySource($"salary" === 100) - * .update(Map( - * "salary" -> lit(200) - * )) - * .merge() - * }}} - * - * @group basic - * @since 4.0.0 - */ + /** @inheritdoc */ def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = { if (isStreaming) { throw new AnalysisException( @@ -1099,7 +1032,7 @@ class Dataset[T] private[sql] ( messageParameters = Map("methodName" -> toSQLId("mergeInto"))) } - new MergeIntoWriter[T](table, this, condition) + new MergeIntoWriterImpl[T](table, this, condition) } /** @@ -1464,6 +1397,22 @@ class Dataset[T] private[sql] ( override def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] = super.dropDuplicatesWithinWatermark(col1, cols: _*) + /** @inheritdoc */ + override def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = + super.mapPartitions(f, encoder) + + /** @inheritdoc */ + override def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] = + super.flatMap(func) + + /** @inheritdoc */ + override def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = + super.flatMap(f, encoder) + + /** @inheritdoc */ + override def foreachPartition(func: ForeachPartitionFunction[T]): Unit = + super.foreachPartition(func) + /** @inheritdoc */ @scala.annotation.varargs override def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = @@ -1515,4 +1464,10 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ @scala.annotation.varargs override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) + + /** @inheritdoc */ + override def groupByKey[K]( + func: MapFunction[T, K], + encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = + super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 04b620bdf8f98..6bf2518901470 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -19,17 +19,19 @@ package org.apache.spark.sql import java.util.Arrays +import scala.annotation.unused import scala.jdk.CollectionConverters._ -import scala.language.existentials import org.apache.spark.api.java.function._ import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.UdfUtils import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.toExpr +import org.apache.spark.sql.internal.UDFAdaptors import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode} /** @@ -39,7 +41,10 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode * * @since 3.5.0 */ -class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { +class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDataset[K, V] { + type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL] + + private def unsupported(): Nothing = throw new UnsupportedOperationException() /** * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the @@ -48,499 +53,52 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { * * @since 3.5.0 */ - def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = { - throw new UnsupportedOperationException - } + def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = unsupported() - /** - * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to - * the data. The grouping key is unchanged by this. - * - * {{{ - * // Create values grouped by key from a Dataset[(K, V)] - * ds.groupByKey(_._1).mapValues(_._2) // Scala - * }}} - * - * @since 3.5.0 - */ - def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = { - throw new UnsupportedOperationException - } + /** @inheritdoc */ + def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = + unsupported() - /** - * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to - * the data. The grouping key is unchanged by this. - * - * {{{ - * // Create Integer values grouped by String key from a Dataset> - * Dataset> ds = ...; - * KeyValueGroupedDataset grouped = - * ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT()); - * }}} - * - * @since 3.5.0 - */ - def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = { - mapValues(UdfUtils.mapFunctionToScalaFunc(func))(encoder) - } - - /** - * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping over - * the Dataset to extract the keys and then running a distinct operation on those. - * - * @since 3.5.0 - */ - def keys: Dataset[K] = { - throw new UnsupportedOperationException - } + /** @inheritdoc */ + def keys: Dataset[K] = unsupported() - /** - * (Scala-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and an iterator that contains all of the elements in - * the group. 
The function can return an iterator containing elements of an arbitrary type which - * will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * @since 3.5.0 - */ - def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { - flatMapSortedGroups()(f) - } - - /** - * (Java-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and an iterator that contains all of the elements in - * the group. The function can return an iterator containing elements of an arbitrary type which - * will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * @since 3.5.0 - */ - def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - flatMapGroups(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder) - } - - /** - * (Scala-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and a sorted iterator that contains all of the elements - * in the group. The function can return an iterator containing elements of an arbitrary type - * which will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be - * sorted according to the given sort expressions. That sorting does not add computational - * complexity. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def flatMapSortedGroups[U: Encoder](sortExprs: Column*)( - f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { - throw new UnsupportedOperationException - } - - /** - * (Java-specific) Applies the given function to each group of data. 
For each unique group, the - * function will be passed the group key and a sorted iterator that contains all of the elements - * in the group. The function can return an iterator containing elements of an arbitrary type - * which will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be - * sorted according to the given sort expressions. That sorting does not add computational - * complexity. - * - * @since 3.5.0 - */ - def flatMapSortedGroups[U]( - SortExprs: Array[Column], - f: FlatMapGroupsFunction[K, V, U], - encoder: Encoder[U]): Dataset[U] = { - import org.apache.spark.util.ArrayImplicits._ - flatMapSortedGroups(SortExprs.toImmutableArraySeq: _*)( - UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder) - } - - /** - * (Scala-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and an iterator that contains all of the elements in - * the group. The function can return an element of arbitrary type which will be returned as a - * new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * @since 3.5.0 - */ - def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { - flatMapGroups(UdfUtils.mapGroupsFuncToFlatMapAdaptor(f)) - } - - /** - * (Java-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and an iterator that contains all of the elements in - * the group. The function can return an element of arbitrary type which will be returned as a - * new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. 
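A minimal sketch of the flatMapGroups contract described in the Scaladoc above, assuming a hypothetical Sale case class, a sales Dataset[Sale], and import spark.implicits._ in scope (none of these are part of this patch):

{{{
case class Sale(dept: String, amount: Double)

// sales: Dataset[Sale]; consume the iterator lazily instead of calling toList.
val maxPerDept = sales
  .groupByKey(_.dept)
  .flatMapGroups { (dept, rows) =>
    var max = Double.MinValue
    rows.foreach(s => max = math.max(max, s.amount))
    Iterator(dept -> max)
  }
}}}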
- * - * @since 3.5.0 - */ - def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - mapGroups(UdfUtils.mapGroupsFuncToScalaFunc(f))(encoder) - } - - /** - * (Scala-specific) Reduces the elements of each group of data using the specified binary - * function. The given function must be commutative and associative or the result may be - * non-deterministic. - * - * @since 3.5.0 - */ - def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { - throw new UnsupportedOperationException - } - - /** - * (Java-specific) Reduces the elements of each group of data using the specified binary - * function. The given function must be commutative and associative or the result may be - * non-deterministic. - * - * @since 3.5.0 - */ - def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = { - reduceGroups(UdfUtils.mapReduceFuncToScalaFunc(f)) - } - - /** - * Internal helper function for building typed aggregations that return tuples. For simplicity - * and code reuse, we do this without the help of the type system and then use helper functions - * that cast appropriately for the user facing interface. - */ - protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - throw new UnsupportedOperationException - } - - /** - * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key and the - * result of computing this aggregation over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = - aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = - aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] + f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = + unsupported() - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = - aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] + /** @inheritdoc */ + def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = unsupported() - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = - aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. 
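reduceGroups needs only a binary function and keeps the value type; a hedged sketch reusing the hypothetical sales Dataset from the previous example:

{{{
// Keep, per department, the single sale with the largest amount.
val largestSale = sales
  .groupByKey(_.dept)
  .reduceGroups((a, b) => if (a.amount >= b.amount) a else b)   // Dataset[(String, Sale)]
}}}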
- * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4, U5]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4], - col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] = - aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5)]] + /** @inheritdoc */ + protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = unsupported() - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4, U5, U6]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4], - col5: TypedColumn[V, U5], - col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] = - aggUntyped(col1, col2, col3, col4, col5, col6) - .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4, U5, U6, U7]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4], - col5: TypedColumn[V, U5], - col6: TypedColumn[V, U6], - col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] = - aggUntyped(col1, col2, col3, col4, col5, col6, col7) - .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4, U5, U6, U7, U8]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4], - col5: TypedColumn[V, U5], - col6: TypedColumn[V, U6], - col7: TypedColumn[V, U7], - col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] = - aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8) - .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]] - - /** - * Returns a [[Dataset]] that contains a tuple with each key and the number of items present for - * that key. - * - * @since 3.5.0 - */ - def count(): Dataset[(K, Long)] = agg(functions.count("*")) - - /** - * (Scala-specific) Applies the given function to each cogrouped data. For each unique group, - * the function will be passed the grouping key and 2 iterators containing all elements in the - * group from [[Dataset]] `this` and `other`. The function can return an iterator containing - * elements of an arbitrary type which will be returned as a new [[Dataset]]. - * - * @since 3.5.0 - */ - def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])( - f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { - cogroupSorted(other)()()(f) - } - - /** - * (Java-specific) Applies the given function to each cogrouped data. For each unique group, the - * function will be passed the grouping key and 2 iterators containing all elements in the group - * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements - * of an arbitrary type which will be returned as a new [[Dataset]]. 
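All of the typed agg overloads above funnel into aggUntyped; from the caller's side the pattern looks like the sketch below (column names are hypothetical, the sales Dataset and spark.implicits._ are assumed from the earlier sketch, and .as[...] is what turns an ordinary Column into the required TypedColumn):

{{{
import org.apache.spark.sql.functions.{avg, count}

val stats = sales
  .groupByKey(_.dept)
  .agg(count("*").as[Long], avg("amount").as[Double])   // Dataset[(String, Long, Double)]
}}}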
- * - * @since 3.5.0 - */ - def cogroup[U, R]( - other: KeyValueGroupedDataset[K, U], - f: CoGroupFunction[K, V, U, R], - encoder: Encoder[R]): Dataset[R] = { - cogroup(other)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder) - } - - /** - * (Scala-specific) Applies the given function to each sorted cogrouped data. For each unique - * group, the function will be passed the grouping key and 2 sorted iterators containing all - * elements in the group from [[Dataset]] `this` and `other`. The function can return an - * iterator containing elements of an arbitrary type which will be returned as a new - * [[Dataset]]. - * - * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be - * sorted according to the given sort expressions. That sorting does not add computational - * complexity. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(thisSortExprs: Column*)( - otherSortExprs: Column*)( - f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { - throw new UnsupportedOperationException - } - - /** - * (Java-specific) Applies the given function to each sorted cogrouped data. For each unique - * group, the function will be passed the grouping key and 2 sorted iterators containing all - * elements in the group from [[Dataset]] `this` and `other`. The function can return an - * iterator containing elements of an arbitrary type which will be returned as a new - * [[Dataset]]. - * - * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be - * sorted according to the given sort expressions. That sorting does not add computational - * complexity. - * - * @since 3.5.0 - */ - def cogroupSorted[U, R]( - other: KeyValueGroupedDataset[K, U], - thisSortExprs: Array[Column], - otherSortExprs: Array[Column], - f: CoGroupFunction[K, V, U, R], - encoder: Encoder[R]): Dataset[R] = { - import org.apache.spark.util.ArrayImplicits._ - cogroupSorted(other)(thisSortExprs.toImmutableArraySeq: _*)( - otherSortExprs.toImmutableArraySeq: _*)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder) - } + otherSortExprs: Column*)(f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = + unsupported() protected[sql] def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder]( outputMode: Option[OutputMode], timeoutConf: GroupStateTimeout, initialState: Option[KeyValueGroupedDataset[K, S]], isMapGroupWithState: Boolean)( - func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = { - throw new UnsupportedOperationException - } + func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = unsupported() - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See - * [[org.apache.spark.sql.streaming.GroupState]] for more details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
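The cogroup and cogroupSorted variants above pair up the groups of two datasets by key; a minimal sketch with hypothetical orders and refunds datasets of type Dataset[(String, Double)]:

{{{
val net = orders.groupByKey(_._1)
  .cogroup(refunds.groupByKey(_._1)) { (key, os, rs) =>
    // Both iterators belong to the same key; emit one net amount per key.
    Iterator(key -> (os.map(_._2).sum - rs.map(_._2).sum))
  }
}}}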
- * @since 3.5.0 - */ + /** @inheritdoc */ def mapGroupsWithState[S: Encoder, U: Encoder]( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { mapGroupsWithState(GroupStateTimeout.NoTimeout)(func) } - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See - * [[org.apache.spark.sql.streaming.GroupState]] for more details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ + /** @inheritdoc */ def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { flatMapGroupsWithStateHelper(None, timeoutConf, None, isMapGroupWithState = true)( - UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func)) + UDFAdaptors.mapGroupsWithStateToFlatMapWithState(func)) } - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See - * [[org.apache.spark.sql.streaming.GroupState]] for more details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param timeoutConf - * Timeout Conf, see GroupStateTimeout for more details - * @param initialState - * The user provided state that will be initialized when the first batch of data is processed - * in the streaming query. The user defined function will be called on the state data even if - * there are no other values in the group. To convert a Dataset ds of type Dataset[(K, S)] to - * a KeyValueGroupedDataset[K, S] do {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ + /** @inheritdoc */ def mapGroupsWithState[S: Encoder, U: Encoder]( timeoutConf: GroupStateTimeout, initialState: KeyValueGroupedDataset[K, S])( @@ -549,134 +107,10 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { None, timeoutConf, Some(initialState), - isMapGroupWithState = true)(UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func)) + isMapGroupWithState = true)(UDFAdaptors.mapGroupsWithStateToFlatMapWithState(func)) } - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. 
For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U]): Dataset[U] = { - mapGroupsWithState[S, U](UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))( - stateEncoder, - outputEncoder) - } - - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout): Dataset[U] = { - mapGroupsWithState[S, U](timeoutConf)(UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))( - stateEncoder, - outputEncoder) - } - - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * @param initialState - * The user provided state that will be initialized when the first batch of data is processed - * in the streaming query. 
The user defined function will be called on the state data even if - * there are no other values in the group. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { - mapGroupsWithState[S, U](timeoutConf, initialState)( - UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(stateEncoder, outputEncoder) - } - - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param outputMode - * The output mode of the function. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ + /** @inheritdoc */ def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout)( @@ -688,33 +122,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { isMapGroupWithState = false)(func) } - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param outputMode - * The output mode of the function. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * @param initialState - * The user provided state that will be initialized when the first batch of data is processed - * in the streaming query. The user defined function will be called on the state data even if - * there are no other values in the group. To covert a Dataset `ds` of type of type - * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use - * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} See [[Encoder]] for more details on what - * types are encodable to Spark SQL. 
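For the mapGroupsWithState variants documented above, a minimal streaming sketch (hypothetical Click events and a running count per user, assuming import spark.implicits._):

{{{
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

case class Click(userId: String)

def countClicks(user: String, events: Iterator[Click],
    state: GroupState[Long]): (String, Long) = {
  val total = state.getOption.getOrElse(0L) + events.size
  state.update(total)
  (user, total)
}

// clicks: streaming Dataset[Click]
val counts = clicks.groupByKey(_.userId)
  .mapGroupsWithState(GroupStateTimeout.NoTimeout)(countClicks)
}}}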
- * @since 3.5.0 - */ + /** @inheritdoc */ def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout, @@ -727,201 +135,244 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { isMapGroupWithState = false)(func) } - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param outputMode - * The output mode of the function. - * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ - def flatMapGroupsWithState[S, U]( + /** @inheritdoc */ + private[sql] def transformWithState[U: Encoder]( + statefulProcessor: StatefulProcessor[K, V, U], + timeMode: TimeMode, + outputMode: OutputMode): Dataset[U] = + unsupported() + + /** @inheritdoc */ + private[sql] def transformWithState[U: Encoder, S: Encoder]( + statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + timeMode: TimeMode, + outputMode: OutputMode, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = + unsupported() + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder]( + statefulProcessor: StatefulProcessor[K, V, U], + eventTimeColumnName: String, + outputMode: OutputMode): Dataset[U] = unsupported() + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder, S: Encoder]( + statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + eventTimeColumnName: String, + outputMode: OutputMode, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = unsupported() + + // Overrides... 
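flatMapGroupsWithState is the zero-or-more-rows sibling of the operator above; a hedged sketch reusing the hypothetical clicks stream and Click class from the previous example:

{{{
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

val milestones = clicks.groupByKey(_.userId)
  .flatMapGroupsWithState(OutputMode.Update(), GroupStateTimeout.NoTimeout) {
    (user: String, events: Iterator[Click], state: GroupState[Long]) =>
      val total = state.getOption.getOrElse(0L) + events.size
      state.update(total)
      // Emit a row only when the running total is an exact multiple of 100.
      if (total % 100 == 0) Iterator(user -> total) else Iterator.empty
  }
}}}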
+ /** @inheritdoc */ + override def mapValues[W]( + func: MapFunction[V, W], + encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = super.mapValues(func, encoder) + + /** @inheritdoc */ + override def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = + super.flatMapGroups(f) + + /** @inheritdoc */ + override def flatMapGroups[U]( + f: FlatMapGroupsFunction[K, V, U], + encoder: Encoder[U]): Dataset[U] = super.flatMapGroups(f, encoder) + + /** @inheritdoc */ + override def flatMapSortedGroups[U]( + SortExprs: Array[Column], + f: FlatMapGroupsFunction[K, V, U], + encoder: Encoder[U]): Dataset[U] = super.flatMapSortedGroups(SortExprs, f, encoder) + + /** @inheritdoc */ + override def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = super.mapGroups(f) + + /** @inheritdoc */ + override def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = + super.mapGroups(f, encoder) + + /** @inheritdoc */ + override def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = + super.mapGroupsWithState(func, stateEncoder, outputEncoder) + + /** @inheritdoc */ + override def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout): Dataset[U] = + super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf) + + /** @inheritdoc */ + override def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = + super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf, initialState) + + /** @inheritdoc */ + override def flatMapGroupsWithState[S, U]( func: FlatMapGroupsWithStateFunction[K, V, S, U], outputMode: OutputMode, stateEncoder: Encoder[S], outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout): Dataset[U] = { - val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func) - flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder) - } + timeoutConf: GroupStateTimeout): Dataset[U] = + super.flatMapGroupsWithState(func, outputMode, stateEncoder, outputEncoder, timeoutConf) - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param outputMode - * The output mode of the function. - * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * @param initialState - * The user provided state that will be initialized when the first batch of data is processed - * in the streaming query. 
The user defined function will be called on the state data even if - * there are no other values in the group. To covert a Dataset `ds` of type of type - * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use - * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ - def flatMapGroupsWithState[S, U]( + /** @inheritdoc */ + override def flatMapGroupsWithState[S, U]( func: FlatMapGroupsWithStateFunction[K, V, S, U], outputMode: OutputMode, stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { - val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func) - flatMapGroupsWithState[S, U](outputMode, timeoutConf, initialState)(f)( - stateEncoder, - outputEncoder) - } - - /** - * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state - * API v2. We allow the user to act on per-group set of input rows along with keyed state and - * the user can choose to output/return 0 or more rows. For a streaming dataframe, we will - * repeatedly invoke the interface methods for new rows in each trigger and the user's - * state/state variables will be stored persistently across invocations. Currently this operator - * is not supported with Spark Connect. - * - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param statefulProcessor - * Instance of statefulProcessor whose functions will be invoked by the operator. - * @param timeMode - * The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode - * The output mode of the stateful processor. - */ - private[sql] def transformWithState[U: Encoder]( + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = super.flatMapGroupsWithState( + func, + outputMode, + stateEncoder, + outputEncoder, + timeoutConf, + initialState) + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeMode: TimeMode, - outputMode: OutputMode): Dataset[U] = { - throw new UnsupportedOperationException - } + outputMode: OutputMode, + outputEncoder: Encoder[U]) = + super.transformWithState(statefulProcessor, timeMode, outputMode, outputEncoder) - /** - * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API - * v2. We allow the user to act on per-group set of input rows along with keyed state and the - * user can choose to output/return 0 or more rows. For a streaming dataframe, we will - * repeatedly invoke the interface methods for new rows in each trigger and the user's - * state/state variables will be stored persistently across invocations. Currently this operator - * is not supported with Spark Connect. - * - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param statefulProcessor - * Instance of statefulProcessor whose functions will be invoked by the operator. - * @param timeMode - * The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode - * The output mode of the stateful processor. - * @param outputEncoder - * Encoder for the output type. 
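The initialState parameter described above expects a KeyValueGroupedDataset; the conversion hinted at in the Scaladoc looks like the following sketch in practice (hypothetical prior counts, reusing the clicks/countClicks example and assuming import spark.implicits._):

{{{
// Dataset[(String, Long)] -> KeyValueGroupedDataset[String, Long]
val priorCounts = spark.createDataset(Seq("alice" -> 40L, "bob" -> 7L))
val initialState = priorCounts.groupByKey(_._1).mapValues(_._2)

val resumedCounts = clicks.groupByKey(_.userId)
  .mapGroupsWithState(GroupStateTimeout.NoTimeout, initialState)(countClicks)
}}}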
- */ - private[sql] def transformWithState[U: Encoder]( + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], - timeMode: TimeMode, + eventTimeColumnName: String, outputMode: OutputMode, - outputEncoder: Encoder[U]): Dataset[U] = { - throw new UnsupportedOperationException - } + outputEncoder: Encoder[U]) = + super.transformWithState(statefulProcessor, eventTimeColumnName, outputMode, outputEncoder) - /** - * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state - * API v2. Functions as the function above, but with additional initial state. Currently this - * operator is not supported with Spark Connect. - * - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @tparam S - * The type of initial state objects. Must be encodable to Spark SQL types. - * @param statefulProcessor - * Instance of statefulProcessor whose functions will be invoked by the operator. - * @param timeMode - * The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode - * The output mode of the stateful processor. - * @param initialState - * User provided initial state that will be used to initiate state for the query in the first - * batch. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - */ - private[sql] def transformWithState[U: Encoder, S: Encoder]( + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeMode: TimeMode, outputMode: OutputMode, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { - throw new UnsupportedOperationException - } - - /** - * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API - * v2. Functions as the function above, but with additional initial state. Currently this - * operator is not supported with Spark Connect. - * - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @tparam S - * The type of initial state objects. Must be encodable to Spark SQL types. - * @param statefulProcessor - * Instance of statefulProcessor whose functions will be invoked by the operator. - * @param timeMode - * The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode - * The output mode of the stateful processor. - * @param initialState - * User provided initial state that will be used to initiate state for the query in the first - * batch. - * @param outputEncoder - * Encoder for the output type. - * @param initialStateEncoder - * Encoder for the initial state type. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
- */ - private[sql] def transformWithState[U: Encoder, S: Encoder]( + initialState: KeyValueGroupedDataset[K, S], + outputEncoder: Encoder[U], + initialStateEncoder: Encoder[S]) = super.transformWithState( + statefulProcessor, + timeMode, + outputMode, + initialState, + outputEncoder, + initialStateEncoder) + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], - timeMode: TimeMode, outputMode: OutputMode, initialState: KeyValueGroupedDataset[K, S], + eventTimeColumnName: String, outputEncoder: Encoder[U], - initialStateEncoder: Encoder[S]): Dataset[U] = { - throw new UnsupportedOperationException - } + initialStateEncoder: Encoder[S]) = super.transformWithState( + statefulProcessor, + outputMode, + initialState, + eventTimeColumnName, + outputEncoder, + initialStateEncoder) + + /** @inheritdoc */ + override def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = super.reduceGroups(f) + + /** @inheritdoc */ + override def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = super.agg(col1) + + /** @inheritdoc */ + override def agg[U1, U2]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = super.agg(col1, col2) + + /** @inheritdoc */ + override def agg[U1, U2, U3]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = super.agg(col1, col2, col3) + + /** @inheritdoc */ + override def agg[U1, U2, U3, U4]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = super.agg(col1, col2, col3, col4) + + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] = + super.agg(col1, col2, col3, col4, col5) + + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5, U6]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] = + super.agg(col1, col2, col3, col4, col5, col6) + + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5, U6, U7]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6], + col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] = + super.agg(col1, col2, col3, col4, col5, col6, col7) + + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5, U6, U7, U8]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6], + col7: TypedColumn[V, U7], + col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] = + super.agg(col1, col2, col3, col4, col5, col6, col7, col8) + + /** @inheritdoc */ + override def count(): Dataset[(K, Long)] = super.count() + + /** @inheritdoc */ + override def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])( + f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = + super.cogroup(other)(f) + + /** @inheritdoc */ + override def cogroup[U, R]( + other: KeyValueGroupedDataset[K, U], + f: CoGroupFunction[K, V, U, R], + encoder: Encoder[R]): Dataset[R] = super.cogroup(other, f, encoder) + + /** 
@inheritdoc */ + override def cogroupSorted[U, R]( + other: KeyValueGroupedDataset[K, U], + thisSortExprs: Array[Column], + otherSortExprs: Array[Column], + f: CoGroupFunction[K, V, U, R], + encoder: Encoder[R]): Dataset[R] = + super.cogroupSorted(other, thisSortExprs, otherSortExprs, f, encoder) } /** @@ -934,12 +385,11 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( private val sparkSession: SparkSession, private val plan: proto.Plan, - private val ikEncoder: AgnosticEncoder[IK], private val kEncoder: AgnosticEncoder[K], private val ivEncoder: AgnosticEncoder[IV], private val vEncoder: AgnosticEncoder[V], private val groupingExprs: java.util.List[proto.Expression], - private val valueMapFunc: IV => V, + private val valueMapFunc: Option[IV => V], private val keysFunc: () => Dataset[IK]) extends KeyValueGroupedDataset[K, V] { import sparkSession.RichColumn @@ -948,7 +398,6 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( new KeyValueGroupedDatasetImpl[L, V, IK, IV]( sparkSession, plan, - ikEncoder, encoderFor[L], ivEncoder, vEncoder, @@ -961,12 +410,13 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( new KeyValueGroupedDatasetImpl[K, W, IK, IV]( sparkSession, plan, - ikEncoder, kEncoder, ivEncoder, encoderFor[W], groupingExprs, - valueMapFunc.andThen(valueFunc), + valueMapFunc + .map(_.andThen(valueFunc)) + .orElse(Option(valueFunc.asInstanceOf[IV => W])), keysFunc) } @@ -979,8 +429,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( override def flatMapSortedGroups[U: Encoder](sortExprs: Column*)( f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { // Apply mapValues changes to the udf - val nf = - if (valueMapFunc == UdfUtils.identical()) f else UdfUtils.mapValuesAdaptor(f, valueMapFunc) + val nf = UDFAdaptors.flatMapGroupsWithMappedValues(f, valueMapFunc) val outputEncoder = encoderFor[U] sparkSession.newDataset[U](outputEncoder) { builder => builder.getGroupMapBuilder @@ -994,10 +443,9 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( override def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])( thisSortExprs: Column*)(otherSortExprs: Column*)( f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { - assert(other.isInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, _]]) - val otherImpl = other.asInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, _]] + val otherImpl = other.asInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, Any]] // Apply mapValues changes to the udf - val nf = UdfUtils.mapValuesAdaptor(f, valueMapFunc, otherImpl.valueMapFunc) + val nf = UDFAdaptors.coGroupWithMappedValues(f, valueMapFunc, otherImpl.valueMapFunc) val outputEncoder = encoderFor[R] sparkSession.newDataset[R](outputEncoder) { builder => builder.getCoGroupMapBuilder @@ -1012,8 +460,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( } override protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - // TODO(SPARK-43415): For each column, apply the valueMap func first - // apply keyAs change + // TODO(SPARK-43415): For each column, apply the valueMap func first... 
val rEnc = ProductEncoder.tuple(kEncoder +: columns.map(c => encoderFor(c.encoder))) sparkSession.newDataset(rEnc) { builder => builder.getAggregateBuilder @@ -1047,22 +494,15 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( throw new IllegalArgumentException("The output mode of function should be append or update") } - if (initialState.isDefined) { - assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]) - } - val initialStateImpl = if (initialState.isDefined) { + assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]) initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]] } else { null } val outputEncoder = encoderFor[U] - val nf = if (valueMapFunc == UdfUtils.identical()) { - func - } else { - UdfUtils.mapValuesAdaptor(func, valueMapFunc) - } + val nf = UDFAdaptors.flatMapGroupsWithStateWithMappedValues(func, valueMapFunc) sparkSession.newDataset[U](outputEncoder) { builder => val groupMapBuilder = builder.getGroupMapBuilder @@ -1097,6 +537,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( * We cannot deserialize a connect [[KeyValueGroupedDataset]] because of a class clash on the * server side. We null out the instance for now. */ + @unused("this is used by java serialization") private def writeReplace(): Any = null } @@ -1114,11 +555,10 @@ private object KeyValueGroupedDatasetImpl { session, ds.plan, kEncoder, - kEncoder, ds.agnosticEncoder, ds.agnosticEncoder, Arrays.asList(toExpr(gf.apply(col("*")))), - UdfUtils.identical(), + None, () => ds.map(groupingFunc)(kEncoder)) } @@ -1137,11 +577,10 @@ private object KeyValueGroupedDatasetImpl { session, df.plan, kEncoder, - kEncoder, vEncoder, vEncoder, (Seq(dummyGroupingFunc) ++ groupingExprs).map(toExpr).asJava, - UdfUtils.identical(), + None, () => df.select(groupingExprs: _*).as(kEncoder)) } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index c9b011ca4535b..14ceb3f4bb144 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.ConnectConversions._ /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -39,8 +40,7 @@ class RelationalGroupedDataset private[sql] ( groupType: proto.Aggregate.GroupType, pivot: Option[proto.Aggregate.Pivot] = None, groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) - extends api.RelationalGroupedDataset[Dataset] { - type RGD = RelationalGroupedDataset + extends api.RelationalGroupedDataset { import df.sparkSession.RichColumn protected def toDF(aggExprs: Seq[Column]): DataFrame = { @@ -80,12 +80,7 @@ class RelationalGroupedDataset private[sql] ( colNames.map(df.col) } - /** - * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions of - * current `RelationalGroupedDataset`. 
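The as[K, T] conversion documented here is what bridges an untyped groupBy back into the typed API; a minimal sketch, assuming a hypothetical Sale schema, a salesDf DataFrame, and import spark.implicits._:

{{{
case class Sale(dept: String, amount: Double)

// salesDf: DataFrame with columns dept and amount
val totals = salesDf
  .groupBy($"dept")
  .as[String, Sale]
  .mapGroups((dept, rows) => dept -> rows.map(_.amount).sum)
}}}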
- * - * @since 3.5.0 - */ + /** @inheritdoc */ def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = { KeyValueGroupedDatasetImpl[K, T](df, encoderFor[K], encoderFor[T], groupingExprs) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 24d0a5ac7262f..04f8eeb5c6d46 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql import java.net.URI +import java.nio.file.{Files, Paths} import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag +import scala.util.Try import com.google.common.cache.{CacheBuilder, CacheLoader} import io.grpc.ClientInterceptor @@ -39,7 +41,7 @@ import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, Spar import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.{CatalogImpl, SessionCleaner, SqlApiConf} +import org.apache.spark.sql.internal.{CatalogImpl, ConnectRuntimeConfig, SessionCleaner, SqlApiConf} import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.{toExpr, toTypedExpr} import org.apache.spark.sql.streaming.DataStreamReader import org.apache.spark.sql.streaming.StreamingQueryManager @@ -67,7 +69,7 @@ import org.apache.spark.util.ArrayImplicits._ class SparkSession private[sql] ( private[sql] val client: SparkConnectClient, private val planIdGenerator: AtomicLong) - extends api.SparkSession[Dataset] + extends api.SparkSession with Logging { private[this] val allocator = new RootAllocator() @@ -86,16 +88,8 @@ class SparkSession private[sql] ( client.hijackServerSideSessionIdForTesting(suffix) } - /** - * Runtime configuration interface for Spark. - * - * This is the interface through which the user can get and set all Spark configurations that - * are relevant to Spark SQL. When getting the value of a config, his defaults to the value set - * in server, if any. - * - * @since 3.4.0 - */ - val conf: RuntimeConfig = new RuntimeConfig(client) + /** @inheritdoc */ + val conf: RuntimeConfig = new ConnectRuntimeConfig(client) /** @inheritdoc */ @transient @@ -212,16 +206,7 @@ class SparkSession private[sql] ( sql(query, Array.empty) } - /** - * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a - * `DataFrame`. - * {{{ - * sparkSession.read.parquet("/path/to/file.parquet") - * sparkSession.read.schema(schema).json("/path/to/file.json") - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def read: DataFrameReader = new DataFrameReader(this) /** @@ -237,12 +222,7 @@ class SparkSession private[sql] ( lazy val streams: StreamingQueryManager = new StreamingQueryManager(this) - /** - * Interface through which the user may create, drop, alter or query underlying databases, - * tables, functions etc. 
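For context on the session-level members touched here (conf, read, catalog), typical Connect-client usage is sketched below; the connection string and file path are placeholders:

{{{
val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()

spark.conf.set("spark.sql.shuffle.partitions", "64")
val df = spark.read.parquet("/path/to/file.parquet")
spark.catalog.listTables().show()
}}}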
- * - * @since 3.5.0 - */ + /** @inheritdoc */ lazy val catalog: Catalog = new CatalogImpl(this) /** @inheritdoc */ @@ -440,7 +420,7 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptAll(): Seq[String] = { + override def interruptAll(): Seq[String] = { client.interruptAll().getInterruptedIdsList.asScala.toSeq } @@ -453,7 +433,7 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptTag(tag: String): Seq[String] = { + override def interruptTag(tag: String): Seq[String] = { client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq } @@ -466,7 +446,7 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptOperation(operationId: String): Seq[String] = { + override def interruptOperation(operationId: String): Seq[String] = { client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq } @@ -497,65 +477,17 @@ class SparkSession private[sql] ( SparkSession.onSessionClose(this) } - /** - * Add a tag to be assigned to all the operations started by this thread in this session. - * - * Often, a unit of execution in an application consists of multiple Spark executions. - * Application programmers can use this method to group all those jobs together and give a group - * tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all - * running running executions with this tag. For example: - * {{{ - * // In the main thread: - * spark.addTag("myjobs") - * spark.range(10).map(i => { Thread.sleep(10); i }).collect() - * - * // In a separate thread: - * spark.interruptTag("myjobs") - * }}} - * - * There may be multiple tags present at the same time, so different parts of application may - * use different tags to perform cancellation at different levels of granularity. - * - * @param tag - * The tag to be added. Cannot contain ',' (comma) character or be an empty string. - * - * @since 3.5.0 - */ - def addTag(tag: String): Unit = { - client.addTag(tag) - } + /** @inheritdoc */ + override def addTag(tag: String): Unit = client.addTag(tag) - /** - * Remove a tag previously added to be assigned to all the operations started by this thread in - * this session. Noop if such a tag was not added earlier. - * - * @param tag - * The tag to be removed. Cannot contain ',' (comma) character or be an empty string. - * - * @since 3.5.0 - */ - def removeTag(tag: String): Unit = { - client.removeTag(tag) - } + /** @inheritdoc */ + override def removeTag(tag: String): Unit = client.removeTag(tag) - /** - * Get the tags that are currently set to be assigned to all the operations started by this - * thread. - * - * @since 3.5.0 - */ - def getTags(): Set[String] = { - client.getTags() - } + /** @inheritdoc */ + override def getTags(): Set[String] = client.getTags() - /** - * Clear the current thread's operation tags. - * - * @since 3.5.0 - */ - def clearTags(): Unit = { - client.clearTags() - } + /** @inheritdoc */ + override def clearTags(): Unit = client.clearTags() /** * We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side. 
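The tag and interrupt methods collapsed into one-liners above work together; the removed Scaladoc's own example, restated as a compact sketch (spark.implicits._ in scope):

{{{
// In the main thread: tag subsequent executions, then run them.
spark.addTag("myjobs")
spark.range(10).map(i => { Thread.sleep(10); i }).collect()

// In a separate thread: cancel everything carrying that tag.
val interrupted: Seq[String] = spark.interruptTag("myjobs")

// Back in the main thread, when done:
spark.clearTags()
}}}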
@@ -591,6 +523,10 @@ class SparkSession private[sql] ( object SparkSession extends Logging { private val MAX_CACHED_SESSIONS = 100 private val planIdGenerator = new AtomicLong + private var server: Option[Process] = None + private[sql] val sparkOptions = sys.props.filter { p => + p._1.startsWith("spark.") && p._2.nonEmpty + }.toMap private val sessions = CacheBuilder .newBuilder() @@ -623,6 +559,51 @@ object SparkSession extends Logging { } } + /** + * Create a new Spark Connect server to connect locally. + */ + private[sql] def withLocalConnectServer[T](f: => T): T = { + synchronized { + val remoteString = sparkOptions + .get("spark.remote") + .orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit + .orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE)) + + val maybeConnectScript = + Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh")) + + if (server.isEmpty && + remoteString.exists(_.startsWith("local")) && + maybeConnectScript.exists(Files.exists(_))) { + server = Some { + val args = + Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions + .filter(p => !p._1.startsWith("spark.remote")) + .flatMap { case (k, v) => Seq("--conf", s"$k=$v") } + val pb = new ProcessBuilder(args: _*) + // So don't exclude spark-sql jar in classpath + pb.environment().remove(SparkConnectClient.SPARK_REMOTE) + pb.start() + } + + // Let the server start. We will directly request to set the configurations + // and this sleep makes less noisy with retries. + Thread.sleep(2000L) + System.setProperty("spark.remote", "sc://localhost") + + // scalastyle:off runtimeaddshutdownhook + Runtime.getRuntime.addShutdownHook(new Thread() { + override def run(): Unit = if (server.isDefined) { + new ProcessBuilder(maybeConnectScript.get.toString) + .start() + } + }) + // scalastyle:on runtimeaddshutdownhook + } + } + f + } + /** * Create a new [[SparkSession]] based on the connect client [[Configuration]]. */ @@ -765,6 +746,16 @@ object SparkSession extends Logging { } private def applyOptions(session: SparkSession): Unit = { + // Only attempts to set Spark SQL configurations. + // If the configurations are static, it might throw an exception so + // simply ignore it for now. 
+ sparkOptions + .filter { case (k, _) => + k.startsWith("spark.sql.") + } + .foreach { case (key, value) => + Try(session.conf.set(key, value)) + } options.foreach { case (key, value) => session.conf.set(key, value) } @@ -787,7 +778,7 @@ object SparkSession extends Logging { * * @since 3.5.0 */ - def create(): SparkSession = { + def create(): SparkSession = withLocalConnectServer { val session = tryCreateSessionFromClient() .getOrElse(SparkSession.this.create(builder.configuration)) setDefaultAndActiveSession(session) @@ -807,7 +798,7 @@ object SparkSession extends Logging { * * @since 3.5.0 */ - def getOrCreate(): SparkSession = { + def getOrCreate(): SparkSession = withLocalConnectServer { val session = tryCreateSessionFromClient() .getOrElse({ var existingSession = sessions.get(builder.configuration) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala index 86775803a0937..63fa2821a6c6a 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.application import java.io.{InputStream, OutputStream} -import java.nio.file.Paths import java.util.concurrent.Semaphore -import scala.util.Try import scala.util.control.NonFatal import ammonite.compiler.CodeClassWrapper @@ -34,6 +32,7 @@ import ammonite.util.Util.newLine import org.apache.spark.SparkBuildInfo.spark_version import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.SparkSession.withLocalConnectServer import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkConnectClientParser} /** @@ -64,37 +63,7 @@ Spark session available as 'spark'. semaphore: Option[Semaphore] = None, inputStream: InputStream = System.in, outputStream: OutputStream = System.out, - errorStream: OutputStream = System.err): Unit = { - val configs: Map[String, String] = - sys.props - .filter(p => - p._1.startsWith("spark.") && - p._2.nonEmpty && - // Don't include spark.remote that we manually set later. - !p._1.startsWith("spark.remote")) - .toMap - - val remoteString: Option[String] = - Option(System.getProperty("spark.remote")) // Set from Spark Submit - .orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE)) - - if (remoteString.exists(_.startsWith("local"))) { - server = Some { - val args = Seq( - Paths.get(sparkHome, "sbin", "start-connect-server.sh").toString, - "--master", - remoteString.get) ++ configs.flatMap { case (k, v) => Seq("--conf", s"$k=$v") } - val pb = new ProcessBuilder(args: _*) - // So don't exclude spark-sql jar in classpath - pb.environment().remove(SparkConnectClient.SPARK_REMOTE) - pb.start() - } - // Let the server start. We will directly request to set the configurations - // and this sleep makes less noisy with retries. - Thread.sleep(2000L) - System.setProperty("spark.remote", "sc://localhost") - } - + errorStream: OutputStream = System.err): Unit = withLocalConnectServer { // Build the client. val client = try { @@ -118,13 +87,6 @@ Spark session available as 'spark'. // Build the session. val spark = SparkSession.builder().client(client).getOrCreate() - - // The configurations might not be all runtime configurations. - // Try to set them with ignoring failures for now. 
- configs - .filter(_._1.startsWith("spark.sql")) - .foreach { case (k, v) => Try(spark.conf.set(k, v)) } - val sparkBind = new Bind("spark", spark) // Add the proper imports and register a [[ClassFinder]]. @@ -197,18 +159,12 @@ Spark session available as 'spark'. } } } - try { - if (semaphore.nonEmpty) { - // Used for testing. - main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get)) - } else { - main.run(sparkBind) - } - } finally { - if (server.isDefined) { - new ProcessBuilder(Paths.get(sparkHome, "sbin", "stop-connect-server.sh").toString) - .start() - } + + if (semaphore.nonEmpty) { + // Used for testing. + main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get)) + } else { + main.run(sparkBind) } } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index cf0fef147ee84..86b1dbe4754e6 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -17,660 +17,152 @@ package org.apache.spark.sql.catalog -import scala.jdk.CollectionConverters._ +import java.util -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} +import org.apache.spark.sql.{api, DataFrame, Dataset} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.types.StructType -import org.apache.spark.storage.StorageLevel -/** - * Catalog interface for Spark. To access this, use `SparkSession.catalog`. - * - * @since 3.5.0 - */ -abstract class Catalog { - - /** - * Returns the current database (namespace) in this session. - * - * @since 3.5.0 - */ - def currentDatabase: String - - /** - * Sets the current database (namespace) in this session. - * - * @since 3.5.0 - */ - def setCurrentDatabase(dbName: String): Unit - - /** - * Returns a list of databases (namespaces) available within the current catalog. - * - * @since 3.5.0 - */ - def listDatabases(): Dataset[Database] - - /** - * Returns a list of databases (namespaces) which name match the specify pattern and available - * within the current catalog. - * - * @since 3.5.0 - */ - def listDatabases(pattern: String): Dataset[Database] - - /** - * Returns a list of tables/views in the current database (namespace). This includes all - * temporary views. - * - * @since 3.5.0 - */ - def listTables(): Dataset[Table] - - /** - * Returns a list of tables/views in the specified database (namespace) (the name can be - * qualified with catalog). This includes all temporary views. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def listTables(dbName: String): Dataset[Table] - - /** - * Returns a list of tables/views in the specified database (namespace) which name match the - * specify pattern (the name can be qualified with catalog). This includes all temporary views. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def listTables(dbName: String, pattern: String): Dataset[Table] - - /** - * Returns a list of functions registered in the current database (namespace). This includes all - * temporary functions. - * - * @since 3.5.0 - */ - def listFunctions(): Dataset[Function] - - /** - * Returns a list of functions registered in the specified database (namespace) (the name can be - * qualified with catalog). This includes all built-in and temporary functions. 
- * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def listFunctions(dbName: String): Dataset[Function] - - /** - * Returns a list of functions registered in the specified database (namespace) which name match - * the specify pattern (the name can be qualified with catalog). This includes all built-in and - * temporary functions. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def listFunctions(dbName: String, pattern: String): Dataset[Function] - - /** - * Returns a list of columns for the given table/view or temporary view. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. It follows the same - * resolution rule with SQL: search for temp views first then table/views in the current - * database (namespace). - * @since 3.5.0 - */ - @throws[AnalysisException]("table does not exist") - def listColumns(tableName: String): Dataset[Column] - - /** - * Returns a list of columns for the given table/view in the specified database under the Hive - * Metastore. - * - * To list columns for table/view in other catalogs, please use `listColumns(tableName)` with - * qualified table/view name instead. - * - * @param dbName - * is an unqualified name that designates a database. - * @param tableName - * is an unqualified name that designates a table/view. - * @since 3.5.0 - */ - @throws[AnalysisException]("database or table does not exist") - def listColumns(dbName: String, tableName: String): Dataset[Column] - - /** - * Get the database (namespace) with the specified name (can be qualified with catalog). This - * throws an AnalysisException when the database (namespace) cannot be found. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def getDatabase(dbName: String): Database - - /** - * Get the table or view with the specified name. This table can be a temporary view or a - * table/view. This throws an AnalysisException when no Table can be found. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. It follows the same - * resolution rule with SQL: search for temp views first then table/views in the current - * database (namespace). - * @since 3.5.0 - */ - @throws[AnalysisException]("table does not exist") - def getTable(tableName: String): Table - - /** - * Get the table or view with the specified name in the specified database under the Hive - * Metastore. This throws an AnalysisException when no Table can be found. - * - * To get table/view in other catalogs, please use `getTable(tableName)` with qualified - * table/view name instead. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database or table does not exist") - def getTable(dbName: String, tableName: String): Table - - /** - * Get the function with the specified name. This function can be a temporary function or a - * function. This throws an AnalysisException when the function cannot be found. - * - * @param functionName - * is either a qualified or unqualified name that designates a function. It follows the same - * resolution rule with SQL: search for built-in/temp functions first then functions in the - * current database (namespace). - * @since 3.5.0 - */ - @throws[AnalysisException]("function does not exist") - def getFunction(functionName: String): Function - - /** - * Get the function with the specified name in the specified database under the Hive Metastore. - * This throws an AnalysisException when the function cannot be found. 
- * - * To get functions in other catalogs, please use `getFunction(functionName)` with qualified - * function name instead. - * - * @param dbName - * is an unqualified name that designates a database. - * @param functionName - * is an unqualified name that designates a function in the specified database - * @since 3.5.0 - */ - @throws[AnalysisException]("database or function does not exist") - def getFunction(dbName: String, functionName: String): Function - - /** - * Check if the database (namespace) with the specified name exists (the name can be qualified - * with catalog). - * - * @since 3.5.0 - */ - def databaseExists(dbName: String): Boolean - - /** - * Check if the table or view with the specified name exists. This can either be a temporary - * view or a table/view. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. It follows the same - * resolution rule with SQL: search for temp views first then table/views in the current - * database (namespace). - * @since 3.5.0 - */ - def tableExists(tableName: String): Boolean - - /** - * Check if the table or view with the specified name exists in the specified database under the - * Hive Metastore. - * - * To check existence of table/view in other catalogs, please use `tableExists(tableName)` with - * qualified table/view name instead. - * - * @param dbName - * is an unqualified name that designates a database. - * @param tableName - * is an unqualified name that designates a table. - * @since 3.5.0 - */ - def tableExists(dbName: String, tableName: String): Boolean - - /** - * Check if the function with the specified name exists. This can either be a temporary function - * or a function. - * - * @param functionName - * is either a qualified or unqualified name that designates a function. It follows the same - * resolution rule with SQL: search for built-in/temp functions first then functions in the - * current database (namespace). - * @since 3.5.0 - */ - def functionExists(functionName: String): Boolean - - /** - * Check if the function with the specified name exists in the specified database under the Hive - * Metastore. - * - * To check existence of functions in other catalogs, please use `functionExists(functionName)` - * with qualified function name instead. - * - * @param dbName - * is an unqualified name that designates a database. - * @param functionName - * is an unqualified name that designates a function. - * @since 3.5.0 - */ - def functionExists(dbName: String, functionName: String): Boolean - - /** - * Creates a table from the given path and returns the corresponding DataFrame. It will use the - * default data source configured by spark.sql.sources.default. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable(tableName: String, path: String): DataFrame = { - createTable(tableName, path) - } - - /** - * Creates a table from the given path and returns the corresponding DataFrame. It will use the - * default data source configured by spark.sql.sources.default. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. 
- * @since 3.5.0 - */ - def createTable(tableName: String, path: String): DataFrame - - /** - * Creates a table from the given path based on a data source and returns the corresponding - * DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable(tableName: String, path: String, source: String): DataFrame = { - createTable(tableName, path, source) - } - - /** - * Creates a table from the given path based on a data source and returns the corresponding - * DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable(tableName: String, path: String, source: String): DataFrame - - /** - * Creates a table from the given path based on a data source and a set of options. Then, - * returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( +/** @inheritdoc */ +abstract class Catalog extends api.Catalog { + + /** @inheritdoc */ + override def listDatabases(): Dataset[Database] + + /** @inheritdoc */ + override def listDatabases(pattern: String): Dataset[Database] + + /** @inheritdoc */ + override def listTables(): Dataset[Table] + + /** @inheritdoc */ + override def listTables(dbName: String): Dataset[Table] + + /** @inheritdoc */ + override def listTables(dbName: String, pattern: String): Dataset[Table] + + /** @inheritdoc */ + override def listFunctions(): Dataset[Function] + + /** @inheritdoc */ + override def listFunctions(dbName: String): Dataset[Function] + + /** @inheritdoc */ + override def listFunctions(dbName: String, pattern: String): Dataset[Function] + + /** @inheritdoc */ + override def listColumns(tableName: String): Dataset[Column] + + /** @inheritdoc */ + override def listColumns(dbName: String, tableName: String): Dataset[Column] + + /** @inheritdoc */ + override def createTable(tableName: String, path: String): DataFrame + + /** @inheritdoc */ + override def createTable(tableName: String, path: String, source: String): DataFrame + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, options) - } - - /** - * Creates a table based on the dataset in a data source and a set of options. Then, returns the - * corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + options: Map[String, String]): DataFrame + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, options.asScala.toMap) - } - - /** - * (Scala-specific) Creates a table from the given path based on a data source and a set of - * options. Then, returns the corresponding DataFrame. 
- * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( + description: String, + options: Map[String, String]): DataFrame + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - options: Map[String, String]): DataFrame = { - createTable(tableName, source, options) - } - - /** - * (Scala-specific) Creates a table based on the dataset in a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable(tableName: String, source: String, options: Map[String, String]): DataFrame - - /** - * Create a table from the given path based on a data source, a schema and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( + schema: StructType, + options: Map[String, String]): DataFrame + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, schema, options) - } - - /** - * Creates a table based on the dataset in a data source and a set of options. Then, returns the - * corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + description: String, + options: Map[String, String]): DataFrame + + /** @inheritdoc */ + override def listCatalogs(): Dataset[CatalogMetadata] + + /** @inheritdoc */ + override def listCatalogs(pattern: String): Dataset[CatalogMetadata] + + /** @inheritdoc */ + override def createExternalTable(tableName: String, path: String): DataFrame = + super.createExternalTable(tableName, path) + + /** @inheritdoc */ + override def createExternalTable(tableName: String, path: String, source: String): DataFrame = + super.createExternalTable(tableName, path, source) + + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, - description: String, - options: java.util.Map[String, String]): DataFrame = { - createTable( - tableName, - source = source, - description = description, - options = options.asScala.toMap) - } - - /** - * (Scala-specific) Creates a table based on the dataset in a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. 
- * @since 3.5.0 - */ - def createTable( + options: util.Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, options) + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - description: String, - options: Map[String, String]): DataFrame + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, options) - /** - * Create a table based on the dataset in a data source, a schema and a set of options. Then, - * returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, schema, options.asScala.toMap) - } - - /** - * (Scala-specific) Create a table from the given path based on a data source, a schema and a - * set of options. Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( + options: Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, options) + + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, schema: StructType, - options: Map[String, String]): DataFrame = { - createTable(tableName, source, schema, options) - } - - /** - * (Scala-specific) Create a table based on the dataset in a data source, a schema and a set of - * options. Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + options: util.Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, schema, options) + + /** @inheritdoc */ + override def createTable( + tableName: String, + source: String, + description: String, + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, description, options) + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, schema: StructType, - options: Map[String, String]): DataFrame + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, schema, options) - /** - * Create a table based on the dataset in a data source, a schema and a set of options. Then, - * returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. 
- * @since 3.5.0 - */ - def createTable( + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, schema: StructType, - description: String, - options: java.util.Map[String, String]): DataFrame = { - createTable( - tableName, - source = source, - schema = schema, - description = description, - options = options.asScala.toMap) - } - - /** - * (Scala-specific) Create a table based on the dataset in a data source, a schema and a set of - * options. Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + options: Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, schema, options) + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, schema: StructType, description: String, - options: Map[String, String]): DataFrame - - /** - * Drops the local temporary view with the given view name in the catalog. If the view has been - * cached before, then it will also be uncached. - * - * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that - * created it, i.e. it will be automatically dropped when the session terminates. It's not tied - * to any databases, i.e. we can't use `db1.view1` to reference a local temporary view. - * - * Note that, the return type of this method was Unit in Spark 2.0, but changed to Boolean in - * Spark 2.1. - * - * @param viewName - * the name of the temporary view to be dropped. - * @return - * true if the view is dropped successfully, false otherwise. - * @since 3.5.0 - */ - def dropTempView(viewName: String): Boolean - - /** - * Drops the global temporary view with the given view name in the catalog. If the view has been - * cached before, then it will also be uncached. - * - * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark - * application, i.e. it will be automatically dropped when the application terminates. It's tied - * to a system preserved database `global_temp`, and we must use the qualified name to refer a - * global temp view, e.g. `SELECT * FROM global_temp.view1`. - * - * @param viewName - * the unqualified name of the temporary view to be dropped. - * @return - * true if the view is dropped successfully, false otherwise. - * @since 3.5.0 - */ - def dropGlobalTempView(viewName: String): Boolean - - /** - * Recovers all the partitions in the directory of a table and update the catalog. Only works - * with a partitioned table, and not a view. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def recoverPartitions(tableName: String): Unit - - /** - * Returns true if the table is currently cached in-memory. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. - * @since 3.5.0 - */ - def isCached(tableName: String): Boolean - - /** - * Caches the specified table in-memory. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. 
If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. - * @since 3.5.0 - */ - def cacheTable(tableName: String): Unit - - /** - * Caches the specified table with the given storage level. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. - * @param storageLevel - * storage level to cache table. - * @since 3.5.0 - */ - def cacheTable(tableName: String, storageLevel: StorageLevel): Unit - - /** - * Removes the specified table from the in-memory cache. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. - * @since 3.5.0 - */ - def uncacheTable(tableName: String): Unit - - /** - * Removes all cached tables from the in-memory cache. - * - * @since 3.5.0 - */ - def clearCache(): Unit - - /** - * Invalidates and refreshes all the cached data and metadata of the given table. For - * performance reasons, Spark SQL or the external data source library it uses might cache - * certain metadata about a table, such as the location of blocks. When those change outside of - * Spark SQL, users should call this function to invalidate the cache. - * - * If this table is cached as an InMemoryRelation, drop the original cached version and make the - * new version cached lazily. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. - * @since 3.5.0 - */ - def refreshTable(tableName: String): Unit - - /** - * Invalidates and refreshes all the cached data (and the associated metadata) for any `Dataset` - * that contains the given data source path. Path matching is by prefix, i.e. "/" would - * invalidate everything that is cached. - * - * @since 3.5.0 - */ - def refreshByPath(path: String): Unit - - /** - * Returns the current catalog in this session. - * - * @since 3.5.0 - */ - def currentCatalog(): String - - /** - * Sets the current catalog in this session. - * - * @since 3.5.0 - */ - def setCurrentCatalog(catalogName: String): Unit - - /** - * Returns a list of catalogs available in this session. - * - * @since 3.5.0 - */ - def listCatalogs(): Dataset[CatalogMetadata] - - /** - * Returns a list of catalogs which name match the specify pattern and available in this - * session. - * - * @since 3.5.0 - */ - def listCatalogs(pattern: String): Dataset[CatalogMetadata] + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, schema, description, options) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala new file mode 100644 index 0000000000000..7d81f4ead7857 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect + +import scala.language.implicitConversions + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql._ + +/** + * Conversions from sql interfaces to the Connect specific implementation. + * + * This class is mainly used by the implementation. In the case of connect it should be extremely + * rare that a developer needs these classes. + * + * We provide both a trait and an object. The trait is useful in situations where an extension + * developer needs to use these conversions in a project covering multiple Spark versions. They + * can create a shim for these conversions, the Spark 4+ version of the shim implements this + * trait, and shims for older versions do not. + */ +@DeveloperApi +trait ConnectConversions { + implicit def castToImpl(session: api.SparkSession): SparkSession = + session.asInstanceOf[SparkSession] + + implicit def castToImpl[T](ds: api.Dataset[T]): Dataset[T] = + ds.asInstanceOf[Dataset[T]] + + implicit def castToImpl(rgds: api.RelationalGroupedDataset): RelationalGroupedDataset = + rgds.asInstanceOf[RelationalGroupedDataset] + + implicit def castToImpl[K, V]( + kvds: api.KeyValueGroupedDataset[K, V]): KeyValueGroupedDataset[K, V] = + kvds.asInstanceOf[KeyValueGroupedDataset[K, V]] +} + +object ConnectConversions extends ConnectConversions diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala similarity index 68% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala rename to connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala index f77dd512ef257..7578e2424fb42 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala @@ -14,10 +14,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.internal import org.apache.spark.connect.proto.{ConfigRequest, ConfigResponse, KeyValue} import org.apache.spark.internal.Logging +import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.connect.client.SparkConnectClient /** @@ -25,61 +26,31 @@ import org.apache.spark.sql.connect.client.SparkConnectClient * * @since 3.4.0 */ -class RuntimeConfig private[sql] (client: SparkConnectClient) extends Logging { +class ConnectRuntimeConfig private[sql] (client: SparkConnectClient) + extends RuntimeConfig + with Logging { - /** - * Sets the given Spark runtime configuration property. 
- * - * @since 3.4.0 - */ + /** @inheritdoc */ def set(key: String, value: String): Unit = { executeConfigRequest { builder => builder.getSetBuilder.addPairsBuilder().setKey(key).setValue(value) } } - /** - * Sets the given Spark runtime configuration property. - * - * @since 3.4.0 - */ - def set(key: String, value: Boolean): Unit = set(key, String.valueOf(value)) - - /** - * Sets the given Spark runtime configuration property. - * - * @since 3.4.0 - */ - def set(key: String, value: Long): Unit = set(key, String.valueOf(value)) - - /** - * Returns the value of Spark runtime configuration property for the given key. - * - * @throws java.util.NoSuchElementException - * if the key is not set and does not have a default value - * @since 3.4.0 - */ + /** @inheritdoc */ @throws[NoSuchElementException]("if the key is not set") def get(key: String): String = getOption(key).getOrElse { throw new NoSuchElementException(key) } - /** - * Returns the value of Spark runtime configuration property for the given key. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def get(key: String, default: String): String = { executeConfigRequestSingleValue { builder => builder.getGetWithDefaultBuilder.addPairsBuilder().setKey(key).setValue(default) } } - /** - * Returns all properties set in this conf. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def getAll: Map[String, String] = { val response = executeConfigRequest { builder => builder.getGetAllBuilder @@ -92,11 +63,7 @@ class RuntimeConfig private[sql] (client: SparkConnectClient) extends Logging { builder.result() } - /** - * Returns the value of Spark runtime configuration property for the given key. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def getOption(key: String): Option[String] = { val pair = executeConfigRequestSinglePair { builder => builder.getGetOptionBuilder.addKeys(key) @@ -108,27 +75,14 @@ class RuntimeConfig private[sql] (client: SparkConnectClient) extends Logging { } } - /** - * Resets the configuration property for the given key. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def unset(key: String): Unit = { executeConfigRequest { builder => builder.getUnsetBuilder.addKeys(key) } } - /** - * Indicates whether the configuration property with the given key is modifiable in the current - * session. - * - * @return - * `true` if the configuration property is modifiable. For static SQL, Spark Core, invalid - * (not existing) and other non-modifiable configuration properties, the returned value is - * `false`. - * @since 3.4.0 - */ + /** @inheritdoc */ def isModifiable(key: String): Boolean = { val modifiable = executeConfigRequestSingleValue { builder => builder.getIsModifiableBuilder.addKeys(key) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala new file mode 100644 index 0000000000000..4afa8b6d566c5 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.connect.proto +import org.apache.spark.sql.{Column, DataFrameWriterV2, Dataset} + +/** + * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 + * API. + * + * @since 3.4.0 + */ +@Experimental +final class DataFrameWriterV2Impl[T] private[sql] (table: String, ds: Dataset[T]) + extends DataFrameWriterV2[T] { + import ds.sparkSession.RichColumn + + private val builder = proto.WriteOperationV2 + .newBuilder() + .setInput(ds.plan.getRoot) + .setTableName(table) + + /** @inheritdoc */ + override def using(provider: String): this.type = { + builder.setProvider(provider) + this + } + + /** @inheritdoc */ + override def option(key: String, value: String): this.type = { + builder.putOptions(key, value) + this + } + + /** @inheritdoc */ + override def options(options: scala.collection.Map[String, String]): this.type = { + builder.putAllOptions(options.asJava) + this + } + + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = { + builder.putAllOptions(options) + this + } + + /** @inheritdoc */ + override def tableProperty(property: String, value: String): this.type = { + builder.putTableProperties(property, value) + this + } + + /** @inheritdoc */ + @scala.annotation.varargs + override def partitionedBy(column: Column, columns: Column*): this.type = { + builder.addAllPartitioningColumns((column +: columns).map(_.expr).asJava) + this + } + + /** @inheritdoc */ + @scala.annotation.varargs + override def clusterBy(colName: String, colNames: String*): this.type = { + builder.addAllClusteringColumns((colName +: colNames).asJava) + this + } + + /** @inheritdoc */ + override def create(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE) + } + + /** @inheritdoc */ + override def replace(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_REPLACE) + } + + /** @inheritdoc */ + override def createOrReplace(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE) + } + + /** @inheritdoc */ + def append(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_APPEND) + } + + /** @inheritdoc */ + def overwrite(condition: Column): Unit = { + builder.setOverwriteCondition(condition.expr) + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE) + } + + /** @inheritdoc */ + def overwritePartitions(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS) + } + + private def executeWriteOperation(mode: proto.WriteOperationV2.Mode): Unit = { + val command = proto.Command + .newBuilder() + .setWriteOperationV2(builder.setMode(mode)) + .build() + ds.sparkSession.execute(command) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala new file mode 100644 index 
0000000000000..fba3c6343558b --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkRuntimeException +import org.apache.spark.annotation.Experimental +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.{Expression, MergeAction, MergeIntoTableCommand} +import org.apache.spark.connect.proto.MergeAction.ActionType._ +import org.apache.spark.sql.{Column, Dataset, MergeIntoWriter} +import org.apache.spark.sql.functions.expr + +/** + * `MergeIntoWriter` provides methods to define and execute merge actions based on specified + * conditions. + * + * @tparam T + * the type of data in the Dataset. + * @param table + * the name of the target table for the merge operation. + * @param ds + * the source Dataset to merge into the target table. + * @param on + * the merge condition. + * + * @since 4.0.0 + */ +@Experimental +class MergeIntoWriterImpl[T] private[sql] (table: String, ds: Dataset[T], on: Column) + extends MergeIntoWriter[T] { + import ds.sparkSession.RichColumn + + private val builder = MergeIntoTableCommand + .newBuilder() + .setTargetTableName(table) + .setSourceTablePlan(ds.plan.getRoot) + .setMergeCondition(on.expr) + + /** + * Executes the merge operation. 
+ */ + def merge(): Unit = { + if (builder.getMatchActionsCount == 0 && + builder.getNotMatchedActionsCount == 0 && + builder.getNotMatchedBySourceActionsCount == 0) { + throw new SparkRuntimeException( + errorClass = "NO_MERGE_ACTION_SPECIFIED", + messageParameters = Map.empty) + } + ds.sparkSession.execute( + proto.Command + .newBuilder() + .setMergeIntoTableCommand(builder.setWithSchemaEvolution(schemaEvolutionEnabled)) + .build()) + } + + override protected[sql] def insertAll(condition: Option[Column]): MergeIntoWriter[T] = { + builder.addNotMatchedActions(buildMergeAction(ACTION_TYPE_INSERT_STAR, condition)) + this + } + + override protected[sql] def insert( + condition: Option[Column], + map: Map[String, Column]): MergeIntoWriter[T] = { + builder.addNotMatchedActions(buildMergeAction(ACTION_TYPE_INSERT, condition, map)) + this + } + + override protected[sql] def updateAll( + condition: Option[Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + appendUpdateDeleteAction( + buildMergeAction(ACTION_TYPE_UPDATE_STAR, condition), + notMatchedBySource) + } + + override protected[sql] def update( + condition: Option[Column], + map: Map[String, Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + appendUpdateDeleteAction( + buildMergeAction(ACTION_TYPE_UPDATE, condition, map), + notMatchedBySource) + } + + override protected[sql] def delete( + condition: Option[Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + appendUpdateDeleteAction(buildMergeAction(ACTION_TYPE_DELETE, condition), notMatchedBySource) + } + + private def appendUpdateDeleteAction( + action: Expression, + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + if (notMatchedBySource) { + builder.addNotMatchedBySourceActions(action) + } else { + builder.addMatchActions(action) + } + this + } + + private def buildMergeAction( + actionType: MergeAction.ActionType, + condition: Option[Column], + assignments: Map[String, Column] = Map.empty): Expression = { + val builder = proto.MergeAction.newBuilder().setActionType(actionType) + condition.foreach(c => builder.setCondition(c.expr)) + assignments.foreach { case (k, v) => + builder + .addAssignmentsBuilder() + .setKey(expr(k).expr) + .setValue(v.expr) + } + Expression + .newBuilder() + .setMergeAction(builder) + .build() + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 13a26fa79085e..29fbcc443deb9 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -18,166 +18,21 @@ package org.apache.spark.sql.streaming import java.util.UUID -import java.util.concurrent.TimeoutException import scala.jdk.CollectionConverters._ -import org.apache.spark.annotation.Evolving import org.apache.spark.connect.proto.Command import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.connect.proto.StreamingQueryCommand import org.apache.spark.connect.proto.StreamingQueryCommandResult import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{api, SparkSession} -/** - * A handle to a query that is executing continuously in the background as new data arrives. All - * these methods are thread-safe. 
- * @since 3.5.0 - */ -@Evolving -trait StreamingQuery { - // This is a copy of StreamingQuery in sql/core/.../streaming/StreamingQuery.scala - - /** - * Returns the user-specified name of the query, or null if not specified. This name can be - * specified in the `org.apache.spark.sql.streaming.DataStreamWriter` as - * `dataframe.writeStream.queryName("query").start()`. This name, if set, must be unique across - * all active queries. - * - * @since 3.5.0 - */ - def name: String - - /** - * Returns the unique id of this query that persists across restarts from checkpoint data. That - * is, this id is generated when a query is started for the first time, and will be the same - * every time it is restarted from checkpoint data. Also see [[runId]]. - * - * @since 3.5.0 - */ - def id: UUID - - /** - * Returns the unique id of this run of the query. That is, every start/restart of a query will - * generate a unique runId. Therefore, every time a query is restarted from checkpoint, it will - * have the same [[id]] but different [[runId]]s. - */ - def runId: UUID - - /** - * Returns the `SparkSession` associated with `this`. - * - * @since 3.5.0 - */ - def sparkSession: SparkSession - - /** - * Returns `true` if this query is actively running. - * - * @since 3.5.0 - */ - def isActive: Boolean - - /** - * Returns the [[StreamingQueryException]] if the query was terminated by an exception. - * @since 3.5.0 - */ - def exception: Option[StreamingQueryException] - - /** - * Returns the current status of the query. - * - * @since 3.5.0 - */ - def status: StreamingQueryStatus - - /** - * Returns an array of the most recent [[StreamingQueryProgress]] updates for this query. The - * number of progress updates retained for each stream is configured by Spark session - * configuration `spark.sql.streaming.numRecentProgressUpdates`. - * - * @since 3.5.0 - */ - def recentProgress: Array[StreamingQueryProgress] - - /** - * Returns the most recent [[StreamingQueryProgress]] update of this streaming query. - * - * @since 3.5.0 - */ - def lastProgress: StreamingQueryProgress - - /** - * Waits for the termination of `this` query, either by `query.stop()` or by an exception. If - * the query has terminated with an exception, then the exception will be thrown. - * - * If the query has terminated, then all subsequent calls to this method will either return - * immediately (if the query was terminated by `stop()`), or throw the exception immediately (if - * the query has terminated with exception). - * - * @throws StreamingQueryException - * if the query has terminated with an exception. - * @since 3.5.0 - */ - @throws[StreamingQueryException] - def awaitTermination(): Unit - - /** - * Waits for the termination of `this` query, either by `query.stop()` or by an exception. If - * the query has terminated with an exception, then the exception will be thrown. Otherwise, it - * returns whether the query has terminated or not within the `timeoutMs` milliseconds. - * - * If the query has terminated, then all subsequent calls to this method will either return - * `true` immediately (if the query was terminated by `stop()`), or throw the exception - * immediately (if the query has terminated with exception). - * - * @throws StreamingQueryException - * if the query has terminated with an exception - * @since 3.5.0 - */ - @throws[StreamingQueryException] - def awaitTermination(timeoutMs: Long): Boolean - - /** - * Blocks until all available data in the source has been processed and committed to the sink. 
- * This method is intended for testing. Note that in the case of continually arriving data, this - * method may block forever. Additionally, this method is only guaranteed to block until data - * that has been synchronously appended data to a - * `org.apache.spark.sql.execution.streaming.Source` prior to invocation. (i.e. `getOffset` must - * immediately reflect the addition). - * @since 3.5.0 - */ - def processAllAvailable(): Unit - - /** - * Stops the execution of this query if it is running. This waits until the termination of the - * query execution threads or until a timeout is hit. - * - * By default stop will block indefinitely. You can configure a timeout by the configuration - * `spark.sql.streaming.stopTimeout`. A timeout of 0 (or negative) milliseconds will block - * indefinitely. If a `TimeoutException` is thrown, users can retry stopping the stream. If the - * issue persists, it is advisable to kill the Spark application. - * - * @since 3.5.0 - */ - @throws[TimeoutException] - def stop(): Unit - - /** - * Prints the physical plan to the console for debugging purposes. - * @since 3.5.0 - */ - def explain(): Unit - - /** - * Prints the physical plan to the console for debugging purposes. - * - * @param extended - * whether to do extended explain or not - * @since 3.5.0 - */ - def explain(extended: Boolean): Unit +/** @inheritdoc */ +trait StreamingQuery extends api.StreamingQuery { + + /** @inheritdoc */ + override def sparkSession: SparkSession } class RemoteStreamingQuery( diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 0926734ef4872..4a76c380f772f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -23,7 +23,7 @@ import java.util.Properties import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} -import scala.concurrent.duration.DurationInt +import scala.concurrent.duration.{DurationInt, FiniteDuration} import scala.jdk.CollectionConverters._ import org.apache.commons.io.FileUtils @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult} +import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, SparkResult} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils, RemoteSparkSession, SQLHelper} @@ -1569,6 +1569,25 @@ class ClientE2ETestSuite val result = df.select(trim(col("col"), " ").as("trimmed_col")).collect() assert(result sameElements Array(Row("a"), Row("b"), Row("c"))) } + + test("SPARK-49673: new batch size, multiple batches") { + val maxBatchSize = spark.conf.get("spark.connect.grpc.arrow.maxBatchSize").dropRight(1).toInt + // Adjust client grpcMaxMessageSize to maxBatchSize (10MiB; set in RemoteSparkSession config) + val sparkWithLowerMaxMessageSize = SparkSession + .builder() + .client( + SparkConnectClient + .builder() + .userId("test") + .port(port) + .grpcMaxMessageSize(maxBatchSize) + 
.retryPolicy(RetryPolicy + .defaultPolicy() + .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s")))) + .build()) + .create() + assert(sparkWithLowerMaxMessageSize.range(maxBatchSize).collect().length == maxBatchSize) + } } private[sql] case class ClassData(a: String, b: Int) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 474eac138ab78..315f80e13eff7 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -552,6 +552,14 @@ class PlanGenerationTestSuite valueColumnName = "value") } + test("transpose index_column") { + simple.transpose(indexColumn = fn.col("id")) + } + + test("transpose no_index_column") { + simple.transpose() + } + test("offset") { simple.offset(1000) } @@ -1801,7 +1809,11 @@ class PlanGenerationTestSuite fn.sentences(fn.col("g")) } - functionTest("sentences with locale") { + functionTest("sentences with language") { + fn.sentences(fn.col("g"), lit("en")) + } + + functionTest("sentences with language and country") { fn.sentences(fn.col("g"), lit("en"), lit("US")) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 10b31155376fb..16f6983efb187 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -145,21 +145,6 @@ object CheckConnectJvmClientCompatibility { checkMiMaCompatibility(clientJar, protobufJar, includedRules, excludeRules) } - private lazy val mergeIntoWriterExcludeRules: Seq[ProblemFilter] = { - // Exclude some auto-generated methods in [[MergeIntoWriter]] classes. - // The incompatible changes are due to the uses of [[proto.Expression]] instead - // of [[catalyst.Expression]] in the method signature. 
- val classNames = Seq("WhenMatched", "WhenNotMatched", "WhenNotMatchedBySource") - val methodNames = Seq("apply", "condition", "copy", "copy$*", "unapply") - - classNames.flatMap { className => - methodNames.map { methodName => - ProblemFilters.exclude[IncompatibleSignatureProblem]( - s"org.apache.spark.sql.$className.$methodName") - } - } - } - private def checkMiMaCompatibilityWithSqlModule( clientJar: File, sqlJar: File): List[Problem] = { @@ -173,6 +158,7 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.columnar.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.classic.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.jdbc.*"), @@ -269,12 +255,6 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.streaming.TestGroupState"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.streaming.TestGroupState$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.streaming.PythonStreamingQueryListener"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.streaming.PythonStreamingQueryListenerWrapper"), - ProblemFilters.exclude[MissingTypesProblem]( - "org.apache.spark.sql.streaming.StreamingQueryListener$Event"), // SQLImplicits ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.rddToDatasetHolder"), @@ -286,10 +266,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.artifact.ArtifactManager$"), - // UDFRegistration - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.register"), - // ColumnNode conversions ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession.Converter"), @@ -304,6 +280,8 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.*"), // UDFRegistration + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.register"), ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.UDFRegistration"), ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.UDFRegistration.log*"), @@ -320,12 +298,13 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.UDFRegistration.initializeLogIfNecessary$default$2"), - // Datasource V2 partition transforms - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.PartitionTransform$ExtractTransform")) ++ - mergeIntoWriterExcludeRules + // Protected DataFrameReader methods... 
+ ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.DataFrameReader.validateSingleVariantColumn"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.DataFrameReader.validateJsonSchema"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.DataFrameReader.validateXmlSchema")) checkMiMaCompatibility(clientJar, sqlJar, includedRules, excludeRules) } @@ -387,21 +366,6 @@ object CheckConnectJvmClientCompatibility { // Experimental ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession.registerClassFinder"), - // public - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptAll"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptOperation"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.addTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.removeTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.getTags"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.clearTags"), // SparkSession#Builder ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession#Builder.remote"), @@ -435,8 +399,7 @@ object CheckConnectJvmClientCompatibility { // Encoders are in the wrong JAR ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders$")) ++ - mergeIntoWriterExcludeRules + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders$")) checkMiMaCompatibility(sqlJar, clientJar, includedRules, excludeRules) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 70b471cf74b33..5397dae9dcc5f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -30,7 +30,7 @@ import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector.VarBinaryVector import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{sql, SparkUnsupportedOperationException} +import org.apache.spark.{sql, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.{AnalysisException, Encoders, Row} import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec, OuterScopes} @@ -776,6 +776,16 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } } + test("kryo serialization") { + val e = intercept[SparkRuntimeException] { + val encoder = sql.encoderFor(Encoders.kryo[(Int, String)]) + roundTripAndCheckIdentical(encoder) { () => + Iterator.tabulate(10)(i => (i, "itr_" + i)) + } + } + assert(e.getErrorClass == "CANNOT_USE_KRYO") + } + test("transforming encoder") { val schema = new StructType() .add("key", IntegerType) diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala index 758262ead7f1e..27b1ee014a719 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala @@ -334,8 +334,6 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L assert(exception.getErrorClass != null) assert(exception.getMessageParameters().get("id") == query.id.toString) assert(exception.getMessageParameters().get("runId") == query.runId.toString) - assert(!exception.getMessageParameters().get("startOffset").isEmpty) - assert(!exception.getMessageParameters().get("endOffset").isEmpty) assert(exception.getCause.isInstanceOf[SparkException]) assert(exception.getCause.getCause.isInstanceOf[SparkException]) assert( @@ -374,8 +372,6 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L assert(exception.getErrorClass != null) assert(exception.getMessageParameters().get("id") == query.id.toString) assert(exception.getMessageParameters().get("runId") == query.runId.toString) - assert(!exception.getMessageParameters().get("startOffset").isEmpty) - assert(!exception.getMessageParameters().get("endOffset").isEmpty) assert(exception.getCause.isInstanceOf[SparkException]) assert(exception.getCause.getCause.isInstanceOf[SparkException]) assert( diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala index a878e42b40aa7..36aaa2cc7fbf6 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala @@ -24,6 +24,9 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration.FiniteDuration import org.scalatest.{BeforeAndAfterAll, Suite} +import org.scalatest.concurrent.Eventually.eventually +import org.scalatest.concurrent.Futures.timeout +import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkBuildInfo import org.apache.spark.sql.SparkSession @@ -121,6 +124,8 @@ object SparkConnectServerUtils { // to make the tests exercise reattach. "spark.connect.execute.reattachable.senderMaxStreamDuration=1s", "spark.connect.execute.reattachable.senderMaxStreamSize=123", + // Testing SPARK-49673, setting maxBatchSize to 10MiB + s"spark.connect.grpc.arrow.maxBatchSize=${10 * 1024 * 1024}", // Disable UI "spark.ui.enabled=false") Seq("--jars", catalystTestJar) ++ confs.flatMap(v => "--conf" :: v :: Nil) @@ -184,12 +189,14 @@ object SparkConnectServerUtils { .port(port) .retryPolicy(RetryPolicy .defaultPolicy() - .copy(maxRetries = Some(7), maxBackoff = Some(FiniteDuration(10, "s")))) + .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s")))) .build()) .create() // Execute an RPC which will get retried until the server is up. - assert(spark.version == SparkBuildInfo.spark_version) + eventually(timeout(1.minute)) { + assert(spark.version == SparkBuildInfo.spark_version) + } // Auto-sync dependencies. 
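[Editor's aside] The eventually(timeout(...)) wrapper introduced above is the standard ScalaTest retry idiom; shown standalone for reference (spark and SparkBuildInfo come from the surrounding test code, everything else mirrors the imports added in this hunk):

    import org.scalatest.concurrent.Eventually.eventually
    import org.scalatest.concurrent.Futures.timeout
    import org.scalatest.time.SpanSugar._

    // The block is re-evaluated until it succeeds or one minute elapses, so a server
    // that is still starting up no longer fails the suite on its first RPC.
    eventually(timeout(1.minute)) {
      assert(spark.version == SparkBuildInfo.spark_version)
    }
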
SparkConnectServerUtils.syncTestDependencies(spark) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index 30009c03c49fd..90cd68e6e1d24 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -490,7 +490,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite { .option("query", "SELECT @myvariant1 as variant1, @myvariant2 as variant2") .load() }, - errorClass = "UNRECOGNIZED_SQL_TYPE", + condition = "UNRECOGNIZED_SQL_TYPE", parameters = Map("typeName" -> "sql_variant", "jdbcType" -> "-156")) } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala index b337eb2fc9b3b..91a82075a3607 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -87,7 +87,7 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"DOUBLE\"", "newType" -> "\"VARCHAR(10)\"", diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 27ec98e9ac451..e5fd453cb057c 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -97,7 +97,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"STRING\"", "newType" -> "\"INT\"", @@ -115,7 +115,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD exception = intercept[SparkSQLFeatureNotSupportedException] { sql(s"ALTER TABLE $tbl ALTER COLUMN ID DROP NOT NULL") }, - errorClass = "_LEGACY_ERROR_TEMP_2271") + condition = "_LEGACY_ERROR_TEMP_2271") } test("SPARK-47440: SQLServer does not support boolean expression in binary comparison") { diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 81aacf2c14d7a..700c05b54a256 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -77,8 +77,19 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest s"""CREATE TABLE 
pattern_testing_table ( |pattern_testing_col LONGTEXT |) - """.stripMargin + |""".stripMargin ).executeUpdate() + connection.prepareStatement( + "CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)") + .executeUpdate() + } + + override def dataPreparation(connection: Connection): Unit = { + super.dataPreparation(connection) + connection.prepareStatement("INSERT INTO datetime VALUES " + + "('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate() + connection.prepareStatement("INSERT INTO datetime VALUES " + + "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { @@ -98,7 +109,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"STRING\"", "newType" -> "\"INT\"", @@ -131,7 +142,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest exception = intercept[SparkSQLFeatureNotSupportedException] { sql(s"ALTER TABLE $tbl ALTER COLUMN ID DROP NOT NULL") }, - errorClass = "_LEGACY_ERROR_TEMP_2271") + condition = "_LEGACY_ERROR_TEMP_2271") } override def testCreateTableWithProperty(tbl: String): Unit = { @@ -157,6 +168,79 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest assert(sql(s"SELECT char_length(c1) from $tableName").head().get(0) === 65536) } } + + override def testDatetime(tbl: String): Unit = { + val df1 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 AND dayofmonth(date1) > 10 ") + checkFilterPushed(df1) + val rows1 = df1.collect() + assert(rows1.length === 2) + assert(rows1(0).getString(0) === "amy") + assert(rows1(1).getString(0) === "alex") + + val df2 = sql(s"SELECT name FROM $tbl WHERE year(date1) = 2022 AND quarter(date1) = 2") + checkFilterPushed(df2) + val rows2 = df2.collect() + assert(rows2.length === 2) + assert(rows2(0).getString(0) === "amy") + assert(rows2(1).getString(0) === "alex") + + val df3 = sql(s"SELECT name FROM $tbl WHERE second(time1) = 0 AND month(date1) = 5") + checkFilterPushed(df3) + val rows3 = df3.collect() + assert(rows3.length === 2) + assert(rows3(0).getString(0) === "amy") + assert(rows3(1).getString(0) === "alex") + + val df4 = sql(s"SELECT name FROM $tbl WHERE hour(time1) = 0 AND minute(time1) = 0") + checkFilterPushed(df4) + val rows4 = df4.collect() + assert(rows4.length === 2) + assert(rows4(0).getString(0) === "amy") + assert(rows4(1).getString(0) === "alex") + + val df5 = sql(s"SELECT name FROM $tbl WHERE " + + "extract(WEEk from date1) > 10 AND extract(YEAROFWEEK from date1) = 2022") + checkFilterPushed(df5) + val rows5 = df5.collect() + assert(rows5.length === 2) + assert(rows5(0).getString(0) === "amy") + assert(rows5(1).getString(0) === "alex") + + val df6 = sql(s"SELECT name FROM $tbl WHERE date_add(date1, 1) = date'2022-05-20' " + + "AND datediff(date1, '2022-05-10') > 0") + checkFilterPushed(df6) + val rows6 = df6.collect() + assert(rows6.length === 1) + assert(rows6(0).getString(0) === "amy") + + val df7 = sql(s"SELECT name FROM $tbl WHERE weekday(date1) = 2") + checkFilterPushed(df7) + val rows7 = df7.collect() + assert(rows7.length === 1) + assert(rows7(0).getString(0) === "alex") + + val df8 = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = 4") + checkFilterPushed(df8) + val rows8 = df8.collect() + assert(rows8.length === 1) + 
assert(rows8(0).getString(0) === "alex") + + val df9 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 order by dayofyear(date1) limit 1") + checkFilterPushed(df9) + val rows9 = df9.collect() + assert(rows9.length === 1) + assert(rows9(0).getString(0) === "alex") + + // MySQL does not support + val df10 = sql(s"SELECT name FROM $tbl WHERE trunc(date1, 'week') = date'2022-05-16'") + checkFilterPushed(df10, false) + val rows10 = df10.collect() + assert(rows10.length === 2) + assert(rows10(0).getString(0) === "amy") + assert(rows10(1).getString(0) === "alex") + } } /** diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala index 42d82233b421b..5e40f0bbc4554 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala @@ -62,7 +62,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac exception = intercept[SparkSQLFeatureNotSupportedException] { catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) }, - errorClass = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE", + condition = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE", parameters = Map("namespace" -> "`foo`") ) assert(catalog.namespaceExists(Array("foo")) === false) @@ -74,7 +74,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac Array("foo"), NamespaceChange.setProperty("comment", "comment for foo")) }, - errorClass = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE", + condition = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE", parameters = Map("namespace" -> "`foo`") ) @@ -82,7 +82,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac exception = intercept[SparkSQLFeatureNotSupportedException] { catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment")) }, - errorClass = "UNSUPPORTED_FEATURE.REMOVE_NAMESPACE_COMMENT", + condition = "UNSUPPORTED_FEATURE.REMOVE_NAMESPACE_COMMENT", parameters = Map("namespace" -> "`foo`") ) @@ -90,7 +90,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac exception = intercept[SparkSQLFeatureNotSupportedException] { catalog.dropNamespace(Array("foo"), cascade = false) }, - errorClass = "UNSUPPORTED_FEATURE.DROP_NAMESPACE", + condition = "UNSUPPORTED_FEATURE.DROP_NAMESPACE", parameters = Map("namespace" -> "`foo`") ) catalog.dropNamespace(Array("foo"), cascade = true) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index 342fb4bb38e60..2c97a588670a8 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -118,7 +118,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"DECIMAL(19,0)\"", "newType" -> "\"INT\"", @@ -139,7 +139,7 @@ class 
OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes exception = intercept[SparkRuntimeException] { sql(s"INSERT INTO $tableName SELECT rpad('hi', 256, 'spark')") }, - errorClass = "EXCEED_LIMIT_LENGTH", + condition = "EXCEED_LIMIT_LENGTH", parameters = Map("limit" -> "255") ) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index e22136a09a56c..850391e8dc33c 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -84,7 +84,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"STRING\"", "newType" -> "\"INT\"", @@ -118,7 +118,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT sql(s"CREATE TABLE $t2(c int)") checkError( exception = intercept[TableAlreadyExistsException](sql(s"ALTER TABLE $t1 RENAME TO t2")), - errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> "`t2`") ) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala index e4cc88cec0f5e..3b1a457214be7 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala @@ -92,7 +92,7 @@ private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerInte catalog.listNamespaces(Array("foo")) } checkError(e, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> "`foo`")) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index b0ab614b27d1f..54635f69f8b65 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -71,7 +71,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`bad_column`", @@ -92,11 +92,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu private def checkErrorFailedJDBC( e: AnalysisException, - errorClass: String, + condition: String, tbl: String): Unit = { checkErrorMatchPVals( exception = e, - errorClass = errorClass, + condition = condition, parameters = Map( "url" -> "jdbc:.*", "tableName" -> s"`$tbl`") @@ -126,7 +126,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = 
intercept[AnalysisException] { sql(s"ALTER TABLE $catalogName.alt_table ADD COLUMNS (C3 DOUBLE)") }, - errorClass = "FIELD_ALREADY_EXISTS", + condition = "FIELD_ALREADY_EXISTS", parameters = Map( "op" -> "add", "fieldNames" -> "`C3`", @@ -159,7 +159,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`bad_column`", @@ -182,7 +182,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`bad_column`", @@ -206,7 +206,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[AnalysisException] { sql(s"ALTER TABLE $catalogName.alt_table RENAME COLUMN ID1 TO ID2") }, - errorClass = "FIELD_ALREADY_EXISTS", + condition = "FIELD_ALREADY_EXISTS", parameters = Map( "op" -> "rename", "fieldNames" -> "`ID2`", @@ -308,7 +308,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[IndexAlreadyExistsException] { sql(s"CREATE index i1 ON $catalogName.new_table (col1)") }, - errorClass = "INDEX_ALREADY_EXISTS", + condition = "INDEX_ALREADY_EXISTS", parameters = Map("indexName" -> "`i1`", "tableName" -> "`new_table`") ) @@ -333,7 +333,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[NoSuchIndexException] { sql(s"DROP index i1 ON $catalogName.new_table") }, - errorClass = "INDEX_NOT_FOUND", + condition = "INDEX_NOT_FOUND", parameters = Map("indexName" -> "`i1`", "tableName" -> "`new_table`") ) } @@ -353,7 +353,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - private def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { + protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { val filter = df.queryExecution.optimizedPlan.collect { case f: Filter => f } @@ -975,9 +975,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[AnalysisException] { sql(s"ALTER TABLE $catalogName.tbl2 RENAME TO tbl1") }, - errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> "`tbl1`") ) } } + + def testDatetime(tbl: String): Unit = {} + + test("scan with filter push-down with date time functions") { + testDatetime(s"$catalogAndNamespace.${caseConvert("datetime")}") + } } diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala index 56456f9b1f776..8d0bcc5816775 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala @@ -50,7 +50,7 @@ private[kafka010] class KafkaRecordToRowConverter { new GenericArrayData(cr.headers.iterator().asScala .map(header => InternalRow(UTF8String.fromString(header.key()), header.value()) - ).toArray) + ).toArray[Any]) } else { null 
} diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 9ae6a9290f80a..1d119de43970f 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -1156,7 +1156,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase with test("allow group.id prefix") { // Group ID prefix is only supported by consumer based offset reader - if (spark.conf.get(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) { + if (sqlConf.getConf(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) { testGroupId("groupIdPrefix", (expected, actual) => { assert(actual.exists(_.startsWith(expected)) && !actual.exists(_ === expected), "Valid consumer groups don't contain the expected group id - " + @@ -1167,7 +1167,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase with test("allow group.id override") { // Group ID override is only supported by consumer based offset reader - if (spark.conf.get(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) { + if (sqlConf.getConf(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) { testGroupId("kafka.group.id", (expected, actual) => { assert(actual.exists(_ === expected), "Valid consumer groups don't " + s"contain the expected group id - Valid consumer groups: $actual / " + diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala index 320485a79e59d..6fc22e7ac5e03 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala @@ -153,7 +153,7 @@ class KafkaOffsetReaderSuite extends QueryTest with SharedSparkSession with Kafk } checkError( exception = ex, - errorClass = "KAFKA_START_OFFSET_DOES_NOT_MATCH_ASSIGNED", + condition = "KAFKA_START_OFFSET_DOES_NOT_MATCH_ASSIGNED", parameters = Map( "specifiedPartitions" -> "Set\\(.*,.*\\)", "assignedPartitions" -> "Set\\(.*,.*,.*\\)"), diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala index 31050887936bd..3b0def8fc73f7 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala @@ -20,7 +20,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.Column -import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.protobuf.utils.ProtobufUtils // scalastyle:off: object.name @@ -71,7 +71,13 @@ object functions { messageName: String, binaryFileDescriptorSet: Array[Byte], options: java.util.Map[String, String]): Column = { - ProtobufDataToCatalyst(data, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap) + Column.fnWithOptions( + "from_protobuf", + options.asScala.iterator, + data, + lit(messageName), + lit(binaryFileDescriptorSet) + ) } /** @@ -90,7 +96,7 @@ object functions { 
@Experimental def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = { val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - ProtobufDataToCatalyst(data, messageName, Some(fileContent)) + from_protobuf(data, messageName, fileContent) } /** @@ -109,7 +115,12 @@ object functions { @Experimental def from_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) : Column = { - ProtobufDataToCatalyst(data, messageName, Some(binaryFileDescriptorSet)) + Column.fn( + "from_protobuf", + data, + lit(messageName), + lit(binaryFileDescriptorSet) + ) } /** @@ -129,7 +140,11 @@ object functions { */ @Experimental def from_protobuf(data: Column, messageClassName: String): Column = { - ProtobufDataToCatalyst(data, messageClassName) + Column.fn( + "from_protobuf", + data, + lit(messageClassName) + ) } /** @@ -153,7 +168,12 @@ object functions { data: Column, messageClassName: String, options: java.util.Map[String, String]): Column = { - ProtobufDataToCatalyst(data, messageClassName, None, options.asScala.toMap) + Column.fnWithOptions( + "from_protobuf", + options.asScala.iterator, + data, + lit(messageClassName) + ) } /** @@ -191,7 +211,12 @@ object functions { @Experimental def to_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) : Column = { - CatalystDataToProtobuf(data, messageName, Some(binaryFileDescriptorSet)) + Column.fn( + "to_protobuf", + data, + lit(messageName), + lit(binaryFileDescriptorSet) + ) } /** * Converts a column into binary of protobuf format. The Protobuf definition is provided @@ -213,7 +238,7 @@ object functions { descFilePath: String, options: java.util.Map[String, String]): Column = { val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - CatalystDataToProtobuf(data, messageName, Some(fileContent), options.asScala.toMap) + to_protobuf(data, messageName, fileContent, options) } /** @@ -237,7 +262,13 @@ object functions { binaryFileDescriptorSet: Array[Byte], options: java.util.Map[String, String] ): Column = { - CatalystDataToProtobuf(data, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap) + Column.fnWithOptions( + "to_protobuf", + options.asScala.iterator, + data, + lit(messageName), + lit(binaryFileDescriptorSet) + ) } /** @@ -257,7 +288,11 @@ object functions { */ @Experimental def to_protobuf(data: Column, messageClassName: String): Column = { - CatalystDataToProtobuf(data, messageClassName) + Column.fn( + "to_protobuf", + data, + lit(messageClassName) + ) } /** @@ -279,6 +314,11 @@ object functions { @Experimental def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String]) : Column = { - CatalystDataToProtobuf(data, messageClassName, None, options.asScala.toMap) + Column.fnWithOptions( + "to_protobuf", + options.asScala.iterator, + data, + lit(messageClassName) + ) } } diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala index 6644bce98293b..e85097a272f24 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala @@ -43,8 +43,8 @@ private[sql] class ProtobufOptions( /** * Adds support for recursive fields. If this option is is not specified, recursive fields are - * not permitted. 
Setting it to 0 drops the recursive fields, 1 allows it to be recursed once, - * and 2 allows it to be recursed twice and so on, up to 10. Values larger than 10 are not + * not permitted. Setting it to 1 drops the recursive fields, 2 allows it to be recursed once, + * and 3 allows it to be recursed twice and so on, up to 10. Values larger than 10 are not allowed in order avoid inadvertently creating very large schemas. If a Protobuf message * has depth beyond this limit, the Spark struct returned is truncated after the recursion limit. * * message Person { string name = 1; Person friend = 2; } * The following lists the schema with different values for this setting. * 1: `struct` - * 2: `struct>` - * 3: `struct>>` + * 2: `struct>` + * 3: `struct>>` * and so on. */ val recursiveFieldMaxDepth: Int = parameters.getOrElse("recursive.fields.max.depth", "-1").toInt @@ -181,7 +181,7 @@ private[sql] class ProtobufOptions( val upcastUnsignedInts: Boolean = parameters.getOrElse("upcast.unsigned.ints", false.toString).toBoolean - // Whether to unwrap the struct representation for well known primitve wrapper types when + // Whether to unwrap the struct representation for well known primitive wrapper types when // deserializing. By default, the wrapper types for primitives (i.e. google.protobuf.Int32Value, // google.protobuf.Int64Value, etc.) will get deserialized as structs. We allow the option to // deserialize them as their respective primitives. @@ -221,7 +221,7 @@ private[sql] class ProtobufOptions( // By default, in the spark schema field a will be dropped, which result in schema // b struct // If retain.empty.message.types=true, field a will be retained by inserting a dummy column. - // b struct, name: string> + // b struct, name: string> val retainEmptyMessage: Boolean = parameters.getOrElse("retain.empty.message.types", false.toString).toBoolean } diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala index 4ff432cf7a055..3eaa91e472c43 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala @@ -708,7 +708,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } checkError( exception = e, - errorClass = "PROTOBUF_DEPENDENCY_NOT_FOUND", + condition = "PROTOBUF_DEPENDENCY_NOT_FOUND", parameters = Map("dependencyName" -> "nestedenum.proto")) } @@ -1057,7 +1057,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } checkError( ex, - errorClass = "PROTOBUF_DESCRIPTOR_FILE_NOT_FOUND", + condition = "PROTOBUF_DESCRIPTOR_FILE_NOT_FOUND", parameters = Map("filePath" -> "/non/existent/path.desc") ) assert(ex.getCause != null) @@ -1699,7 +1699,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } checkError( exception = parseError, - errorClass = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE", + condition = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE", parameters = Map( "sqlColumn" -> "`basic_enum`", "protobufColumn" -> "field 'basic_enum'", @@ -1711,7 +1711,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } checkError( exception = parseError, - errorClass = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE", + condition = 
"CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE", parameters = Map( "sqlColumn" -> "`basic_enum`", "protobufColumn" -> "field 'basic_enum'", @@ -2093,7 +2093,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot | to_protobuf(complex_struct, 42, '$testFileDescFile', map()) |FROM protobuf_test_table |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> s"""\"to_protobuf(complex_struct, 42, $testFileDescFile, map())\"""", "msg" -> ("The second argument of the TO_PROTOBUF SQL function must be a constant " + @@ -2111,11 +2111,11 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot | to_protobuf(complex_struct, 'SimpleMessageJavaTypes', 42, map()) |FROM protobuf_test_table |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> "\"to_protobuf(complex_struct, SimpleMessageJavaTypes, 42, map())\"", "msg" -> ("The third argument of the TO_PROTOBUF SQL function must be a constant " + - "string representing the Protobuf descriptor file path"), + "string or binary data representing the Protobuf descriptor file path"), "hint" -> ""), queryContext = Array(ExpectedContext( fragment = "to_protobuf(complex_struct, 'SimpleMessageJavaTypes', 42, map())", @@ -2130,7 +2130,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot | to_protobuf(complex_struct, 'SimpleMessageJavaTypes', '$testFileDescFile', 42) |FROM protobuf_test_table |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> s"""\"to_protobuf(complex_struct, SimpleMessageJavaTypes, $testFileDescFile, 42)\"""", @@ -2152,7 +2152,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot |SELECT from_protobuf(protobuf_data, 42, '$testFileDescFile', map()) |FROM ($toProtobufSql) |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> s"""\"from_protobuf(protobuf_data, 42, $testFileDescFile, map())\"""", "msg" -> ("The second argument of the FROM_PROTOBUF SQL function must be a constant " + @@ -2169,11 +2169,11 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot |SELECT from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', 42, map()) |FROM ($toProtobufSql) |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> "\"from_protobuf(protobuf_data, SimpleMessageJavaTypes, 42, map())\"", "msg" -> ("The third argument of the FROM_PROTOBUF SQL function must be a constant " + - "string representing the Protobuf descriptor file path"), + "string or binary data representing the Protobuf descriptor file path"), "hint" -> ""), queryContext = Array(ExpectedContext( fragment = "from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', 42, map())", @@ -2188,7 +2188,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot | from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', '$testFileDescFile', 42) |FROM ($toProtobufSql) |""".stripMargin)), - errorClass = 
"DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> s"""\"from_protobuf(protobuf_data, SimpleMessageJavaTypes, $testFileDescFile, 42)\"""", diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala index 03285c73f1ff1..2737bb9feb3ad 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala @@ -95,7 +95,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { protoFile, Deserializer, fieldMatch, - errorClass = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE", + condition = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE", params = Map( "protobufType" -> "MissMatchTypeInRoot", "toType" -> toSQLType(CATALYST_STRUCT))) @@ -104,7 +104,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { protoFile, Serializer, fieldMatch, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "MissMatchTypeInRoot", "toType" -> toSQLType(CATALYST_STRUCT))) @@ -122,7 +122,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { protoFile, Serializer, BY_NAME, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "FieldMissingInProto", "toType" -> toSQLType(CATALYST_STRUCT))) @@ -132,7 +132,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Serializer, BY_NAME, nonnullCatalyst, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "FieldMissingInProto", "toType" -> toSQLType(nonnullCatalyst))) @@ -150,7 +150,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Deserializer, fieldMatch, catalyst, - errorClass = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE", + condition = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE", params = Map( "protobufType" -> "MissMatchTypeInDeepNested", "toType" -> toSQLType(catalyst))) @@ -160,7 +160,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Serializer, fieldMatch, catalyst, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "MissMatchTypeInDeepNested", "toType" -> toSQLType(catalyst))) @@ -177,7 +177,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Serializer, BY_NAME, catalystSchema = foobarSQLType, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "FoobarWithRequiredFieldBar", "toType" -> toSQLType(foobarSQLType))) @@ -199,7 +199,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Serializer, BY_NAME, catalystSchema = nestedFoobarSQLType, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "NestedFoobarWithRequiredFieldBar", "toType" -> toSQLType(nestedFoobarSQLType))) @@ -222,7 +222,7 @@ class 
ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { checkError( exception = e1, - errorClass = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR") + condition = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR") val basicMessageDescWithoutImports = descriptorSetWithoutImports( ProtobufUtils.readDescriptorFileContent( @@ -240,7 +240,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { checkError( exception = e2, - errorClass = "PROTOBUF_DEPENDENCY_NOT_FOUND", + condition = "PROTOBUF_DEPENDENCY_NOT_FOUND", parameters = Map("dependencyName" -> "nestedenum.proto")) } @@ -254,7 +254,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { serdeFactory: SerdeFactory[_], fieldMatchType: MatchType, catalystSchema: StructType = CATALYST_STRUCT, - errorClass: String, + condition: String, params: Map[String, String]): Unit = { val e = intercept[AnalysisException] { @@ -274,7 +274,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { assert(e.getMessage === expectMsg) checkError( exception = e, - errorClass = errorClass, + condition = condition, parameters = params) } diff --git a/core/pom.xml b/core/pom.xml index 0a339e11a5d20..19f58940ed942 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -32,7 +32,6 @@ core - **/OpenTelemetry*.scala @@ -122,19 +121,14 @@ io.jsonwebtoken jjwt-api - 0.12.6 io.jsonwebtoken jjwt-impl - 0.12.6 - test io.jsonwebtoken jjwt-jackson - 0.12.6 - test @@ -627,34 +613,10 @@ - opentelemetry + jjwt - + compile - - - io.opentelemetry - opentelemetry-exporter-otlp - 1.41.0 - - - io.opentelemetry - opentelemetry-sdk-extension-autoconfigure-spi - - - - - io.opentelemetry - opentelemetry-sdk - 1.41.0 - - - com.squareup.okhttp3 - okhttp - 3.12.12 - test - - sparkr diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index 4e251a1c2901b..412d612c7f1d5 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -17,6 +17,7 @@ package org.apache.spark.io; import org.apache.spark.storage.StorageUtils; +import org.apache.spark.unsafe.Platform; import java.io.File; import java.io.IOException; @@ -47,7 +48,7 @@ public final class NioBufferedFileInputStream extends InputStream { private final FileChannel fileChannel; public NioBufferedFileInputStream(File file, int bufferSizeInBytes) throws IOException { - byteBuffer = ByteBuffer.allocateDirect(bufferSizeInBytes); + byteBuffer = Platform.allocateDirectBuffer(bufferSizeInBytes); fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ); byteBuffer.flip(); this.cleanable = CLEANER.register(this, new ResourceCleaner(fileChannel, byteBuffer)); diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 485f0abcd25ee..042179d86c31a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -27,6 +27,7 @@ import scala.collection.Map import scala.collection.concurrent.{Map => ScalaConcurrentMap} import scala.collection.immutable import scala.collection.mutable.HashMap +import scala.concurrent.{Future, Promise} import scala.jdk.CollectionConverters._ import scala.reflect.{classTag, ClassTag} import scala.util.control.NonFatal @@ -909,10 +910,20 @@ class SparkContext(config: SparkConf) extends Logging { * * @since 3.5.0 */ - 
def addJobTag(tag: String): Unit = { - SparkContext.throwIfInvalidTag(tag) + def addJobTag(tag: String): Unit = addJobTags(Set(tag)) + + /** + * Add multiple tags to be assigned to all the jobs started by this thread. + * See [[addJobTag]] for more details. + * + * @param tags The tags to be added. Cannot contain ',' (comma) character. + * + * @since 4.0.0 + */ + def addJobTags(tags: Set[String]): Unit = { + tags.foreach(SparkContext.throwIfInvalidTag) val existingTags = getJobTags() - val newTags = (existingTags + tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP) + val newTags = (existingTags ++ tags).mkString(SparkContext.SPARK_JOB_TAGS_SEP) setLocalProperty(SparkContext.SPARK_JOB_TAGS, newTags) } @@ -924,10 +935,20 @@ class SparkContext(config: SparkConf) extends Logging { * * @since 3.5.0 */ - def removeJobTag(tag: String): Unit = { - SparkContext.throwIfInvalidTag(tag) + def removeJobTag(tag: String): Unit = removeJobTags(Set(tag)) + + /** + * Remove multiple tags to be assigned to all the jobs started by this thread. + * See [[removeJobTag]] for more details. + * + * @param tags The tags to be removed. Cannot contain ',' (comma) character. + * + * @since 4.0.0 + */ + def removeJobTags(tags: Set[String]): Unit = { + tags.foreach(SparkContext.throwIfInvalidTag) val existingTags = getJobTags() - val newTags = (existingTags - tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP) + val newTags = (existingTags -- tags).mkString(SparkContext.SPARK_JOB_TAGS_SEP) if (newTags.isEmpty) { clearJobTags() } else { @@ -2684,6 +2705,25 @@ class SparkContext(config: SparkConf) extends Logging { dagScheduler.cancelJobGroup(groupId, cancelFutureJobs = true, None) } + /** + * Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`. + * + * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. + * @param reason reason for cancellation. + * @return A future with [[ActiveJob]]s, allowing extraction of information such as Job ID and + * tags. + */ + private[spark] def cancelJobsWithTagWithFuture( + tag: String, + reason: String): Future[Seq[ActiveJob]] = { + SparkContext.throwIfInvalidTag(tag) + assertNotStopped() + + val cancelledJobs = Promise[Seq[ActiveJob]]() + dagScheduler.cancelJobsWithTag(tag, Some(reason), Some(cancelledJobs)) + cancelledJobs.future + } + /** * Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`. * @@ -2695,7 +2735,7 @@ class SparkContext(config: SparkConf) extends Logging { def cancelJobsWithTag(tag: String, reason: String): Unit = { SparkContext.throwIfInvalidTag(tag) assertNotStopped() - dagScheduler.cancelJobsWithTag(tag, Option(reason)) + dagScheduler.cancelJobsWithTag(tag, Option(reason), cancelledJobs = None) } /** @@ -2708,7 +2748,7 @@ class SparkContext(config: SparkConf) extends Logging { def cancelJobsWithTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) assertNotStopped() - dagScheduler.cancelJobsWithTag(tag, None) + dagScheduler.cancelJobsWithTag(tag, reason = None, cancelledJobs = None) } /** Cancel all jobs that have been scheduled or are running. 
*/ diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index f8b7cdcf7a8b0..6b7fc1b0804b7 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -73,7 +73,7 @@ class SparkEnv ( // We initialize the ShuffleManager later in SparkContext and Executor to allow // user jars to define custom ShuffleManagers. - private var _shuffleManager: ShuffleManager = _ + @volatile private var _shuffleManager: ShuffleManager = _ def shuffleManager: ShuffleManager = _shuffleManager diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 667bf8bbc9754..e9507fa6bee48 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -25,6 +25,9 @@ import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import scala.util.Try +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.{SparkConf, SparkUserAppException} import org.apache.spark.api.python.{Py4JServer, PythonUtils} import org.apache.spark.internal.config._ @@ -50,18 +53,21 @@ object PythonRunner { val formattedPythonFile = formatPath(pythonFile) val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles)) - val gatewayServer = new Py4JServer(sparkConf) + var gatewayServer: Option[Py4JServer] = None + if (sparkConf.getOption("spark.remote").isEmpty) { + gatewayServer = Some(new Py4JServer(sparkConf)) - val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.start() }) - thread.setName("py4j-gateway-init") - thread.setDaemon(true) - thread.start() + val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.get.start() }) + thread.setName("py4j-gateway-init") + thread.setDaemon(true) + thread.start() - // Wait until the gateway server has started, so that we know which port is it bound to. - // `gatewayServer.start()` will start a new thread and run the server code there, after - // initializing the socket, so the thread started above will end as soon as the server is - // ready to serve connections. - thread.join() + // Wait until the gateway server has started, so that we know which port is it bound to. + // `gatewayServer.start()` will start a new thread and run the server code there, after + // initializing the socket, so the thread started above will end as soon as the server is + // ready to serve connections. + thread.join() + } // Build up a PYTHONPATH that includes the Spark assembly (where this class is), the // python directories in SPARK_HOME (if set), and any files in the pyFiles argument @@ -74,12 +80,22 @@ object PythonRunner { // Launch Python process val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava) val env = builder.environment() + if (sparkConf.getOption("spark.remote").nonEmpty) { + // For non-local remote, pass configurations to environment variables so + // Spark Connect client sets them. For local remotes, they will be set + // via Py4J. 
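[Editor's aside] The chunked environment variables written just below are ultimately read by the PySpark Connect client; purely to illustrate the layout (a count plus numbered JSON objects of at most ten conf entries each), a hedged Scala sketch of decoding it would be:

    import org.json4s._
    import org.json4s.jackson.JsonMethods.parse

    // Illustrative decoder only; variable names other than the env-var keys are hypothetical.
    implicit val formats: Formats = DefaultFormats
    val chunks = sys.env.getOrElse("PYSPARK_REMOTE_INIT_CONF_LEN", "0").toInt
    val restoredConf: Map[String, String] = (0 until chunks)
      .flatMap(i => parse(sys.env(s"PYSPARK_REMOTE_INIT_CONF_$i")).extract[Map[String, String]])
      .toMap
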
+ val grouped = sparkConf.getAll.toMap.grouped(10).toSeq + env.put("PYSPARK_REMOTE_INIT_CONF_LEN", grouped.length.toString) + grouped.zipWithIndex.foreach { case (group, idx) => + env.put(s"PYSPARK_REMOTE_INIT_CONF_$idx", compact(render(group))) + } + } sparkConf.getOption("spark.remote").foreach(url => env.put("SPARK_REMOTE", url)) env.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string - env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) - env.put("PYSPARK_GATEWAY_SECRET", gatewayServer.secret) + gatewayServer.foreach(s => env.put("PYSPARK_GATEWAY_PORT", s.getListeningPort.toString)) + gatewayServer.foreach(s => env.put("PYSPARK_GATEWAY_SECRET", s.secret)) // pass conf spark.pyspark.python to python process, the only way to pass info to // python process is through environment variable. sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _)) @@ -103,7 +119,7 @@ object PythonRunner { throw new SparkUserAppException(exitCode) } } finally { - gatewayServer.shutdown() + gatewayServer.foreach(_.shutdown()) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 32dd2f81bbc82..2c9ddff348056 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -43,7 +43,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S extends SparkSubmitArgumentsParser with Logging { var maybeMaster: Option[String] = None // Global defaults. These should be keep to minimum to avoid confusing behavior. - def master: String = maybeMaster.getOrElse("local[*]") + def master: String = + maybeMaster.getOrElse(System.getProperty("spark.test.master", "local[*]")) var maybeRemote: Option[String] = None var deployMode: String = null var executorMemory: String = null diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 0e19143411e96..c5646d2956aeb 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -271,6 +271,18 @@ package object config { .toSequence .createWithDefault(GarbageCollectionMetrics.OLD_GENERATION_BUILTIN_GARBAGE_COLLECTORS) + private[spark] val EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS = + ConfigBuilder("spark.eventLog.includeTaskMetricsAccumulators") + .doc("Whether to include TaskMetrics' underlying accumulator values in the event log (as " + + "part of the Task/Stage/Job metrics' 'Accumulables' fields. This configuration defaults " + + "to false because the TaskMetrics values are already logged in the 'Task Metrics' " + + "fields (so the accumulator updates are redundant). This flag exists only as a " + + "backwards-compatibility escape hatch for applications that might rely on the old " + + "behavior. 
See SPARK-42204 for details.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + private[spark] val EVENT_LOG_OVERWRITE = ConfigBuilder("spark.eventLog.overwrite") .version("1.0.0") @@ -1374,7 +1386,6 @@ package object config { private[spark] val SHUFFLE_ACCURATE_BLOCK_SKEWED_FACTOR = ConfigBuilder("spark.shuffle.accurateBlockSkewedFactor") - .internal() .doc("A shuffle block is considered as skewed and will be accurately recorded in " + "HighlyCompressedMapStatus if its size is larger than this factor multiplying " + "the median shuffle block size or SHUFFLE_ACCURATE_BLOCK_THRESHOLD. It is " + diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushReporter.scala b/core/src/main/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushReporter.scala deleted file mode 100644 index bab7023ecdf11..0000000000000 --- a/core/src/main/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushReporter.scala +++ /dev/null @@ -1,355 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.metrics.sink.opentelemetry - -import java.nio.file.{Files, Paths} -import java.util.{Locale, SortedMap} -import java.util.concurrent.TimeUnit - -import com.codahale.metrics._ -import io.opentelemetry.api.common.Attributes -import io.opentelemetry.api.metrics.{DoubleGauge, DoubleHistogram, LongCounter} -import io.opentelemetry.exporter.otlp.metrics.OtlpGrpcMetricExporter -import io.opentelemetry.sdk.OpenTelemetrySdk -import io.opentelemetry.sdk.metrics.SdkMeterProvider -import io.opentelemetry.sdk.metrics.export.PeriodicMetricReader -import io.opentelemetry.sdk.resources.Resource - -private[spark] class OpenTelemetryPushReporter( - registry: MetricRegistry, - pollInterval: Int = 10, - pollUnit: TimeUnit = TimeUnit.SECONDS, - endpoint: String = "http://localhost:4317", - headersMap: Map[String, String] = Map(), - attributesMap: Map[String, String] = Map(), - trustedCertificatesPath: String, - privateKeyPemPath: String, - certificatePemPath: String) - extends ScheduledReporter ( - registry, - "opentelemetry-push-reporter", - MetricFilter.ALL, - TimeUnit.SECONDS, - TimeUnit.MILLISECONDS) - with MetricRegistryListener { - - val FIFTEEN_MINUTE_RATE = "_fifteen_minute_rate" - val FIVE_MINUTE_RATE = "_five_minute_rate" - val ONE_MINUTE_RATE = "_one_minute_rate" - val MEAN_RATE = "_mean_rate" - val METER = "_meter" - val TIMER = "_timer" - val COUNT = "_count" - val MAX = "_max" - val MIN = "_min" - val MEAN = "_mean" - val MEDIAN = "_50_percentile" - val SEVENTY_FIFTH_PERCENTILE = "_75_percentile" - val NINETY_FIFTH_PERCENTILE = "_95_percentile" - val NINETY_EIGHTH_PERCENTILE = "_98_percentile" - val NINETY_NINTH_PERCENTILE = "_99_percentile" - val NINE_HUNDRED_NINETY_NINTH_PERCENTILE = "_999_percentile" - val STD_DEV = "_std_dev" - - val otlpGrpcMetricExporterBuilder = OtlpGrpcMetricExporter.builder() - - for ((key, value) <- headersMap) { - otlpGrpcMetricExporterBuilder.addHeader(key, value) - } - - if (trustedCertificatesPath != null) { - otlpGrpcMetricExporterBuilder - .setTrustedCertificates(Files.readAllBytes(Paths.get(trustedCertificatesPath))) - } - - if (privateKeyPemPath != null && certificatePemPath != null) { - otlpGrpcMetricExporterBuilder - .setClientTls( - Files.readAllBytes(Paths.get(privateKeyPemPath)), - Files.readAllBytes(Paths.get(certificatePemPath))) - } - - otlpGrpcMetricExporterBuilder.setEndpoint(endpoint) - - val arrtributesBuilder = Attributes.builder() - for ((key, value) <- attributesMap) { - arrtributesBuilder.put(key, value) - } - - val resource = Resource - .getDefault() - .merge(Resource.create(arrtributesBuilder.build())); - - val metricReader = PeriodicMetricReader - .builder(otlpGrpcMetricExporterBuilder.build()) - .setInterval(pollInterval, pollUnit) - .build() - - val sdkMeterProvider: SdkMeterProvider = SdkMeterProvider - .builder() - .registerMetricReader(metricReader) - .setResource(resource) - .build() - - val openTelemetryCounters = collection.mutable.Map[String, LongCounter]() - val openTelemetryHistograms = collection.mutable.Map[String, DoubleHistogram]() - val openTelemetryGauges = collection.mutable.Map[String, DoubleGauge]() - val codahaleCounters = collection.mutable.Map[String, Counter]() - val openTelemetry = OpenTelemetrySdk - .builder() - .setMeterProvider(sdkMeterProvider) - .build(); - val openTelemetryMeter = openTelemetry.getMeter("apache-spark") - - override def report( - gauges: SortedMap[String, Gauge[_]], - counters: SortedMap[String, Counter], - histograms: SortedMap[String, Histogram], - meters: 
SortedMap[String, Meter], - timers: SortedMap[String, Timer]): Unit = { - counters.forEach(this.reportCounter) - gauges.forEach(this.reportGauges) - histograms.forEach(this.reportHistograms) - meters.forEach(this.reportMeters) - timers.forEach(this.reportTimers) - sdkMeterProvider.forceFlush - } - - override def onGaugeAdded(name: String, gauge: Gauge[_]): Unit = { - val metricName = normalizeMetricName(name) - generateGauge(metricName) - } - - override def onGaugeRemoved(name: String): Unit = { - val metricName = normalizeMetricName(name) - openTelemetryGauges.remove(metricName) - } - - override def onCounterAdded(name: String, counter: Counter): Unit = { - val metricName = normalizeMetricName(name) - val addedOpenTelemetryCounter = - openTelemetryMeter.counterBuilder(normalizeMetricName(metricName)).build - openTelemetryCounters.put(metricName, addedOpenTelemetryCounter) - codahaleCounters.put(metricName, registry.counter(metricName)) - } - - override def onCounterRemoved(name: String): Unit = { - val metricName = normalizeMetricName(name) - openTelemetryCounters.remove(metricName) - codahaleCounters.remove(metricName) - } - - override def onHistogramAdded(name: String, histogram: Histogram): Unit = { - val metricName = normalizeMetricName(name) - generateHistogramGroup(metricName) - } - - override def onHistogramRemoved(name: String): Unit = { - val metricName = normalizeMetricName(name) - cleanHistogramGroup(metricName) - } - - override def onMeterAdded(name: String, meter: Meter): Unit = { - val metricName = normalizeMetricName(name) + METER - generateGauge(metricName + COUNT) - generateGauge(metricName + MEAN_RATE) - generateGauge(metricName + FIFTEEN_MINUTE_RATE) - generateGauge(metricName + FIVE_MINUTE_RATE) - generateGauge(metricName + ONE_MINUTE_RATE) - } - - override def onMeterRemoved(name: String): Unit = { - val metricName = normalizeMetricName(name) + METER - openTelemetryGauges.remove(metricName + COUNT) - openTelemetryGauges.remove(metricName + MEAN_RATE) - openTelemetryGauges.remove(metricName + ONE_MINUTE_RATE) - openTelemetryGauges.remove(metricName + FIVE_MINUTE_RATE) - openTelemetryGauges.remove(metricName + FIFTEEN_MINUTE_RATE) - } - - override def onTimerAdded(name: String, timer: Timer): Unit = { - val metricName = normalizeMetricName(name) + TIMER - generateHistogramGroup(metricName) - generateAdditionalHistogramGroupForTimers(metricName) - } - - override def onTimerRemoved(name: String): Unit = { - val metricName = normalizeMetricName(name) + TIMER - cleanHistogramGroup(name) - cleanAdditionalHistogramGroupTimers(metricName) - } - - override def stop(): Unit = { - super.stop() - sdkMeterProvider.close() - } - - private def normalizeMetricName(name: String): String = { - name.toLowerCase(Locale.ROOT).replaceAll("[^a-z0-9]", "_") - } - - private def generateHistogram(metricName: String): Unit = { - val openTelemetryHistogram = - openTelemetryMeter.histogramBuilder(metricName).build - openTelemetryHistograms.put(metricName, openTelemetryHistogram) - } - - private def generateHistogramGroup(metricName: String): Unit = { - generateHistogram(metricName + COUNT) - generateHistogram(metricName + MAX) - generateHistogram(metricName + MIN) - generateHistogram(metricName + MEAN) - generateHistogram(metricName + MEDIAN) - generateHistogram(metricName + STD_DEV) - generateHistogram(metricName + SEVENTY_FIFTH_PERCENTILE) - generateHistogram(metricName + NINETY_FIFTH_PERCENTILE) - generateHistogram(metricName + NINETY_EIGHTH_PERCENTILE) - generateHistogram(metricName + 
NINETY_NINTH_PERCENTILE) - generateHistogram(metricName + NINE_HUNDRED_NINETY_NINTH_PERCENTILE) - } - - private def generateAdditionalHistogramGroupForTimers(metricName: String): Unit = { - generateHistogram(metricName + FIFTEEN_MINUTE_RATE) - generateHistogram(metricName + FIVE_MINUTE_RATE) - generateHistogram(metricName + ONE_MINUTE_RATE) - generateHistogram(metricName + MEAN_RATE) - } - - private def cleanHistogramGroup(metricName: String): Unit = { - openTelemetryHistograms.remove(metricName + COUNT) - openTelemetryHistograms.remove(metricName + MAX) - openTelemetryHistograms.remove(metricName + MIN) - openTelemetryHistograms.remove(metricName + MEAN) - openTelemetryHistograms.remove(metricName + MEDIAN) - openTelemetryHistograms.remove(metricName + STD_DEV) - openTelemetryHistograms.remove(metricName + SEVENTY_FIFTH_PERCENTILE) - openTelemetryHistograms.remove(metricName + NINETY_FIFTH_PERCENTILE) - openTelemetryHistograms.remove(metricName + NINETY_EIGHTH_PERCENTILE) - openTelemetryHistograms.remove(metricName + NINETY_NINTH_PERCENTILE) - openTelemetryHistograms.remove(metricName + NINE_HUNDRED_NINETY_NINTH_PERCENTILE) - } - - private def cleanAdditionalHistogramGroupTimers(metricName: String): Unit = { - openTelemetryHistograms.remove(metricName + FIFTEEN_MINUTE_RATE) - openTelemetryHistograms.remove(metricName + FIVE_MINUTE_RATE) - openTelemetryHistograms.remove(metricName + ONE_MINUTE_RATE) - openTelemetryHistograms.remove(metricName + MEAN_RATE) - } - - private def generateGauge(metricName: String): Unit = { - val addedOpenTelemetryGauge = - openTelemetryMeter.gaugeBuilder(normalizeMetricName(metricName)).build - openTelemetryGauges.put(metricName, addedOpenTelemetryGauge) - } - - private def reportCounter(name: String, counter: Counter): Unit = { - val metricName = normalizeMetricName(name) - val openTelemetryCounter = openTelemetryCounters(metricName) - val codahaleCounter = codahaleCounters(metricName) - val diff = counter.getCount - codahaleCounter.getCount - openTelemetryCounter.add(diff) - codahaleCounter.inc(diff) - } - - private def reportGauges(name: String, gauge: Gauge[_]): Unit = { - val metricName = normalizeMetricName(name) - gauge.getValue match { - case d: Double => - openTelemetryGauges(metricName).set(d.doubleValue) - case d: Long => - openTelemetryGauges(metricName).set(d.doubleValue) - case d: Int => - openTelemetryGauges(metricName).set(d.doubleValue) - case _ => () - } - } - - private def reportHistograms(name: String, histogram: Histogram): Unit = { - val metricName = normalizeMetricName(name) - reportHistogramGroup(metricName, histogram) - } - - private def reportMeters(name: String, meter: Meter): Unit = { - val metricName = normalizeMetricName(name) + METER - val openTelemetryGaugeCount = openTelemetryGauges(metricName + COUNT) - openTelemetryGaugeCount.set(meter.getCount.toDouble) - val openTelemetryGauge0neMinuteRate = openTelemetryGauges(metricName + ONE_MINUTE_RATE) - openTelemetryGauge0neMinuteRate.set(meter.getOneMinuteRate) - val openTelemetryGaugeFiveMinuteRate = openTelemetryGauges(metricName + FIVE_MINUTE_RATE) - openTelemetryGaugeFiveMinuteRate.set(meter.getFiveMinuteRate) - val openTelemetryGaugeFifteenMinuteRate = openTelemetryGauges( - metricName + FIFTEEN_MINUTE_RATE) - openTelemetryGaugeFifteenMinuteRate.set(meter.getFifteenMinuteRate) - val openTelemetryGaugeMeanRate = openTelemetryGauges(metricName + MEAN_RATE) - openTelemetryGaugeMeanRate.set(meter.getMeanRate) - } - - private def reportTimers(name: String, timer: Timer): Unit = { - 
val metricName = normalizeMetricName(name) + TIMER - val openTelemetryHistogramMax = openTelemetryHistograms(metricName + MAX) - openTelemetryHistogramMax.record(timer.getCount.toDouble) - val openTelemetryHistogram0neMinuteRate = openTelemetryHistograms( - metricName + ONE_MINUTE_RATE) - openTelemetryHistogram0neMinuteRate.record(timer.getOneMinuteRate) - val openTelemetryHistogramFiveMinuteRate = openTelemetryHistograms( - metricName + FIVE_MINUTE_RATE) - openTelemetryHistogramFiveMinuteRate.record(timer.getFiveMinuteRate) - val openTelemetryHistogramFifteenMinuteRate = openTelemetryHistograms( - metricName + FIFTEEN_MINUTE_RATE) - openTelemetryHistogramFifteenMinuteRate.record(timer.getFifteenMinuteRate) - val openTelemetryHistogramMeanRate = openTelemetryHistograms(metricName + MEAN_RATE) - openTelemetryHistogramMeanRate.record(timer.getMeanRate) - val snapshot = timer.getSnapshot - reportHistogramGroup(metricName, snapshot) - } - - private def reportHistogramGroup(metricName: String, histogram: Histogram): Unit = { - val openTelemetryHistogramCount = openTelemetryHistograms(metricName + COUNT) - openTelemetryHistogramCount.record(histogram.getCount.toDouble) - val snapshot = histogram.getSnapshot - reportHistogramGroup(metricName, snapshot) - } - - private def reportHistogramGroup(metricName: String, snapshot: Snapshot): Unit = { - val openTelemetryHistogramMax = openTelemetryHistograms(metricName + MAX) - openTelemetryHistogramMax.record(snapshot.getMax.toDouble) - val openTelemetryHistogramMin = openTelemetryHistograms(metricName + MIN) - openTelemetryHistogramMin.record(snapshot.getMin.toDouble) - val openTelemetryHistogramMean = openTelemetryHistograms(metricName + MEAN) - openTelemetryHistogramMean.record(snapshot.getMean) - val openTelemetryHistogramMedian = openTelemetryHistograms(metricName + MEDIAN) - openTelemetryHistogramMedian.record(snapshot.getMedian) - val openTelemetryHistogramStdDev = openTelemetryHistograms(metricName + STD_DEV) - openTelemetryHistogramStdDev.record(snapshot.getStdDev) - val openTelemetryHistogram75Percentile = openTelemetryHistograms( - metricName + SEVENTY_FIFTH_PERCENTILE) - openTelemetryHistogram75Percentile.record(snapshot.get75thPercentile) - val openTelemetryHistogram95Percentile = openTelemetryHistograms( - metricName + NINETY_FIFTH_PERCENTILE) - openTelemetryHistogram95Percentile.record(snapshot.get95thPercentile) - val openTelemetryHistogram98Percentile = openTelemetryHistograms( - metricName + NINETY_EIGHTH_PERCENTILE) - openTelemetryHistogram98Percentile.record(snapshot.get98thPercentile) - val openTelemetryHistogram99Percentile = openTelemetryHistograms( - metricName + NINETY_NINTH_PERCENTILE) - openTelemetryHistogram99Percentile.record(snapshot.get99thPercentile) - val openTelemetryHistogram999Percentile = openTelemetryHistograms( - metricName + NINE_HUNDRED_NINETY_NINTH_PERCENTILE) - openTelemetryHistogram999Percentile.record(snapshot.get999thPercentile) - } -} diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushSink.scala deleted file mode 100644 index 23d047f585efc..0000000000000 --- a/core/src/main/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushSink.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.metrics.sink.opentelemetry - -import java.util.{Locale, Properties} -import java.util.concurrent.TimeUnit - -import com.codahale.metrics.MetricRegistry -import org.apache.commons.lang3.StringUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.metrics.sink.Sink - -private[spark] object OpenTelemetryPushSink { - private def fetchMapFromProperties( - properties: Properties, - keyPrefix: String): Map[String, String] = { - val propertiesMap = scala.collection.mutable.Map[String, String]() - val valueEnumeration = properties.propertyNames - val dotCount = keyPrefix.count(_ == '.') - while (valueEnumeration.hasMoreElements) { - val key = valueEnumeration.nextElement.asInstanceOf[String] - if (key.startsWith(keyPrefix)) { - val dotIndex = StringUtils.ordinalIndexOf(key, ".", dotCount + 1) - val mapKey = key.substring(dotIndex + 1) - propertiesMap(mapKey) = properties.getProperty(key) - } - } - propertiesMap.toMap - } -} - -private[spark] class OpenTelemetryPushSink(val property: Properties, val registry: MetricRegistry) - extends Sink with Logging { - - val OPEN_TELEMETRY_KEY_PERIOD = "period" - val OPEN_TELEMETRY_KEY_UNIT = "unit" - val OPEN_TELEMETRY_DEFAULT_PERIOD = "10" - val OPEN_TELEMETRY_DEFAULT_UNIT = "SECONDS" - val OPEN_TELEMETRY_KEY_ENDPOINT = "endpoint" - val GRPC_METRIC_EXPORTER_HEADER_KEY = "grpc.metric.exporter.header" - val GRPC_METRIC_EXPORTER_ATTRIBUTES_KEY = "grpc.metric.exporter.attributes" - val TRUSTED_CERTIFICATE_PATH = "trusted.certificate.path" - val PRIVATE_KEY_PEM_PATH = "private.key.pem.path" - val CERTIFICATE_PEM_PATH = "certificate.pem.path" - - val pollPeriod = property - .getProperty(OPEN_TELEMETRY_KEY_PERIOD, OPEN_TELEMETRY_DEFAULT_PERIOD) - .toInt - - val pollUnit = TimeUnit.valueOf( - property - .getProperty(OPEN_TELEMETRY_KEY_UNIT, OPEN_TELEMETRY_DEFAULT_UNIT) - .toUpperCase(Locale.ROOT)) - - val endpoint = property.getProperty(OPEN_TELEMETRY_KEY_ENDPOINT) - - val headersMap = - OpenTelemetryPushSink.fetchMapFromProperties(property, GRPC_METRIC_EXPORTER_HEADER_KEY) - val attributesMap = - OpenTelemetryPushSink.fetchMapFromProperties(property, GRPC_METRIC_EXPORTER_ATTRIBUTES_KEY) - - val trustedCertificatesPath: String = - property.getProperty(TRUSTED_CERTIFICATE_PATH) - - val privateKeyPemPath: String = property.getProperty(PRIVATE_KEY_PEM_PATH) - - val certificatePemPath: String = property.getProperty(CERTIFICATE_PEM_PATH) - - val reporter = new OpenTelemetryPushReporter( - registry, - pollInterval = pollPeriod, - pollUnit, - endpoint, - headersMap, - attributesMap, - trustedCertificatesPath, - privateKeyPemPath, - certificatePemPath) - - registry.addListener(reporter) - - override def start(): Unit = { - reporter.start(pollPeriod, pollUnit) - } - - override def stop(): Unit = { - reporter.stop() - } - - override def 
report(): Unit = { - reporter.report() - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 6c824e2fdeaed..2c89fe7885d08 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -27,6 +27,7 @@ import scala.annotation.tailrec import scala.collection.Map import scala.collection.mutable import scala.collection.mutable.{HashMap, HashSet, ListBuffer} +import scala.concurrent.Promise import scala.concurrent.duration._ import scala.util.control.NonFatal @@ -1116,11 +1117,18 @@ private[spark] class DAGScheduler( /** * Cancel all jobs with a given tag. + * + * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. + * @param reason reason for cancellation. + * @param cancelledJobs a promise to be completed with operation IDs being cancelled. */ - def cancelJobsWithTag(tag: String, reason: Option[String]): Unit = { + def cancelJobsWithTag( + tag: String, + reason: Option[String], + cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = { SparkContext.throwIfInvalidTag(tag) logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}") - eventProcessLoop.post(JobTagCancelled(tag, reason)) + eventProcessLoop.post(JobTagCancelled(tag, reason, cancelledJobs)) } /** @@ -1234,17 +1242,22 @@ private[spark] class DAGScheduler( jobIds.foreach(handleJobCancellation(_, Option(updatedReason))) } - private[scheduler] def handleJobTagCancelled(tag: String, reason: Option[String]): Unit = { - // Cancel all jobs belonging that have this tag. + private[scheduler] def handleJobTagCancelled( + tag: String, + reason: Option[String], + cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = { + // Cancel all jobs that have all provided tags. // First finds all active jobs with this group id, and then kill stages for them. 
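      // A minimal caller-side sketch of the new optional promise (illustrative; assumes a
      // DAGScheduler reference named dagScheduler is in scope). The promise is completed
      // below with the ActiveJobs that actually matched the tag:
      //   import scala.concurrent.Promise
      //   val cancelled = Promise[Seq[ActiveJob]]()
      //   dagScheduler.cancelJobsWithTag("myTag", Some("user requested cancellation"), Some(cancelled))
      //   cancelled.future  // eventually holds the Seq[ActiveJob] that were cancelled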
- val jobIds = activeJobs.filter { activeJob => + val jobsToBeCancelled = activeJobs.filter { activeJob => Option(activeJob.properties).exists { properties => Option(properties.getProperty(SparkContext.SPARK_JOB_TAGS)).getOrElse("") .split(SparkContext.SPARK_JOB_TAGS_SEP).filter(!_.isEmpty).toSet.contains(tag) } - }.map(_.jobId) - val updatedReason = reason.getOrElse("part of cancelled job tag %s".format(tag)) - jobIds.foreach(handleJobCancellation(_, Option(updatedReason))) + } + val updatedReason = + reason.getOrElse("part of cancelled job tags %s".format(tag)) + jobsToBeCancelled.map(_.jobId).foreach(handleJobCancellation(_, Option(updatedReason))) + cancelledJobs.map(_.success(jobsToBeCancelled.toSeq)) } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = { @@ -3113,8 +3126,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case JobGroupCancelled(groupId, cancelFutureJobs, reason) => dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason) - case JobTagCancelled(tag, reason) => - dagScheduler.handleJobTagCancelled(tag, reason) + case JobTagCancelled(tag, reason, cancelledJobs) => + dagScheduler.handleJobTagCancelled(tag, reason, cancelledJobs) case AllJobsCancelled => dagScheduler.doCancelAllJobs() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index c9ad54d1fdc7e..8932d2ef323ba 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -19,6 +19,8 @@ package org.apache.spark.scheduler import java.util.Properties +import scala.concurrent.Promise + import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.util.{AccumulatorV2, CallSite} @@ -71,7 +73,8 @@ private[scheduler] case class JobGroupCancelled( private[scheduler] case class JobTagCancelled( tagName: String, - reason: Option[String]) extends DAGSchedulerEvent + reason: Option[String], + cancelledJobs: Option[Promise[Seq[ActiveJob]]]) extends DAGSchedulerEvent private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index efd8fecb974e8..1e46142fab255 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -31,7 +31,7 @@ import org.apache.spark.deploy.history.EventLogFileWriter import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ -import org.apache.spark.util.{JsonProtocol, Utils} +import org.apache.spark.util.{JsonProtocol, JsonProtocolOptions, Utils} /** * A SparkListener that logs events to persistent storage. @@ -74,6 +74,8 @@ private[spark] class EventLoggingListener( private val liveStageExecutorMetrics = mutable.HashMap.empty[(Int, Int), mutable.HashMap[String, ExecutorMetrics]] + private[this] val jsonProtocolOptions = new JsonProtocolOptions(sparkConf) + /** * Creates the log file in the configured log directory. 
*/ @@ -84,7 +86,7 @@ private[spark] class EventLoggingListener( private def initEventLog(): Unit = { val metadata = SparkListenerLogStart(SPARK_VERSION) - val eventJson = JsonProtocol.sparkEventToJsonString(metadata) + val eventJson = JsonProtocol.sparkEventToJsonString(metadata, jsonProtocolOptions) logWriter.writeEvent(eventJson, flushLogger = true) if (testing && loggedEvents != null) { loggedEvents += eventJson @@ -93,7 +95,7 @@ private[spark] class EventLoggingListener( /** Log the event as JSON. */ private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false): Unit = { - val eventJson = JsonProtocol.sparkEventToJsonString(event) + val eventJson = JsonProtocol.sparkEventToJsonString(event, jsonProtocolOptions) logWriter.writeEvent(eventJson, flushLogger) if (testing) { loggedEvents += eventJson diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index cc19b71bfc4d6..384f939a843bc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -22,8 +22,6 @@ import javax.annotation.Nullable import scala.collection.Map -import com.fasterxml.jackson.annotation.JsonTypeInfo - import org.apache.spark.TaskEndReason import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} @@ -31,13 +29,6 @@ import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} -@DeveloperApi -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") -trait SparkListenerEvent { - /* Whether output this event to the event log */ - protected[spark] def logEvent: Boolean = true -} - @DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) extends SparkListenerEvent diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala index 686ac1eb786e0..f29e8778da037 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala @@ -60,7 +60,13 @@ class BlockManagerStorageEndpoint( if (mapOutputTracker != null) { mapOutputTracker.unregisterShuffle(shuffleId) } - SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId) + val shuffleManager = SparkEnv.get.shuffleManager + if (shuffleManager != null) { + shuffleManager.unregisterShuffle(shuffleId) + } else { + logDebug(log"Ignore remove shuffle ${MDC(SHUFFLE_ID, shuffleId)}") + true + } } case DecommissionBlockManager => diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 1498b224b0c92..3e57094b36a7e 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -35,6 +35,7 @@ import org.apache.spark.internal.LogKeys._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils} import org.apache.spark.security.CryptoStreamUtils +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils import 
org.apache.spark.util.io.ChunkedByteBuffer @@ -324,7 +325,7 @@ private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: private var _transferred = 0L - private val buffer = ByteBuffer.allocateDirect(64 * 1024) + private val buffer = Platform.allocateDirectBuffer(64 * 1024) buffer.flip() override def count(): Long = blockSize diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index c3dc459b0f88e..fff6ec4f5b170 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -449,16 +449,24 @@ private[spark] object UIUtils extends Logging { val startRatio = if (total == 0) 0.0 else (boundedStarted.toDouble / total) * 100 val startWidth = "width: %s%%".format(startRatio) + val killTaskReasonText = reasonToNumKilled.toSeq.sortBy(-_._2).map { + case (reason, count) => s" ($count killed: $reason)" + }.mkString + val progressTitle = s"$completed/$total" + { + if (started > 0) s" ($started running)" else "" + } + { + if (failed > 0) s" ($failed failed)" else "" + } + { + if (skipped > 0) s" ($skipped skipped)" else "" + } + killTaskReasonText +
- + {completed}/{total} { if (failed == 0 && skipped == 0 && started > 0) s"($started running)" } { if (failed > 0) s"($failed failed)" } { if (skipped > 0) s"($skipped skipped)" } - { reasonToNumKilled.toSeq.sortBy(-_._2).map { - case (reason, count) => s"($count killed: $reason)" - } - } + { killTaskReasonText }
diff --git a/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala index a4145bb36acc9..1683e892511f9 100644 --- a/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala @@ -57,7 +57,7 @@ private[spark] class DirectByteBufferOutputStream(capacity: Int) extends OutputS if (newCapacity < minCapacity) newCapacity = minCapacity val oldBuffer = buffer oldBuffer.flip() - val newBuffer = ByteBuffer.allocateDirect(newCapacity) + val newBuffer = Platform.allocateDirectBuffer(newCapacity) newBuffer.put(oldBuffer) StorageUtils.dispose(oldBuffer) buffer = newBuffer diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 19cefbc0479a9..e30380f41566a 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -28,6 +28,7 @@ import org.json4s.jackson.JsonMethods.compact import org.apache.spark._ import org.apache.spark.executor._ +import org.apache.spark.internal.config._ import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.{DeterministicLevel, RDDOperationScope} import org.apache.spark.resource.{ExecutorResourceRequest, ResourceInformation, ResourceProfile, TaskResourceRequest} @@ -37,6 +38,16 @@ import org.apache.spark.storage._ import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils.weakIntern +/** + * Helper class for passing configuration options to JsonProtocol. + * We use this instead of passing SparkConf directly because it lets us avoid + * repeated re-parsing of configuration values on each read. + */ +private[spark] class JsonProtocolOptions(conf: SparkConf) { + val includeTaskMetricsAccumulators: Boolean = + conf.get(EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS) +} + /** * Serializes SparkListener events to/from JSON. This protocol provides strong backwards- * and forwards-compatibility guarantees: any version of Spark should be able to read JSON output @@ -55,30 +66,41 @@ import org.apache.spark.util.Utils.weakIntern private[spark] object JsonProtocol extends JsonUtils { // TODO: Remove this file and put JSON serialization into each individual class. + private[util] + val defaultOptions: JsonProtocolOptions = new JsonProtocolOptions(new SparkConf(false)) + /** ------------------------------------------------- * * JSON serialization methods for SparkListenerEvents | * -------------------------------------------------- */ + // Only for use in tests. Production code should use the two-argument overload defined below. 
def sparkEventToJsonString(event: SparkListenerEvent): String = { + sparkEventToJsonString(event, defaultOptions) + } + + def sparkEventToJsonString(event: SparkListenerEvent, options: JsonProtocolOptions): String = { toJsonString { generator => - writeSparkEventToJson(event, generator) + writeSparkEventToJson(event, generator, options) } } - def writeSparkEventToJson(event: SparkListenerEvent, g: JsonGenerator): Unit = { + def writeSparkEventToJson( + event: SparkListenerEvent, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { event match { case stageSubmitted: SparkListenerStageSubmitted => - stageSubmittedToJson(stageSubmitted, g) + stageSubmittedToJson(stageSubmitted, g, options) case stageCompleted: SparkListenerStageCompleted => - stageCompletedToJson(stageCompleted, g) + stageCompletedToJson(stageCompleted, g, options) case taskStart: SparkListenerTaskStart => - taskStartToJson(taskStart, g) + taskStartToJson(taskStart, g, options) case taskGettingResult: SparkListenerTaskGettingResult => - taskGettingResultToJson(taskGettingResult, g) + taskGettingResultToJson(taskGettingResult, g, options) case taskEnd: SparkListenerTaskEnd => - taskEndToJson(taskEnd, g) + taskEndToJson(taskEnd, g, options) case jobStart: SparkListenerJobStart => - jobStartToJson(jobStart, g) + jobStartToJson(jobStart, g, options) case jobEnd: SparkListenerJobEnd => jobEndToJson(jobEnd, g) case environmentUpdate: SparkListenerEnvironmentUpdate => @@ -112,12 +134,15 @@ private[spark] object JsonProtocol extends JsonUtils { } } - def stageSubmittedToJson(stageSubmitted: SparkListenerStageSubmitted, g: JsonGenerator): Unit = { + def stageSubmittedToJson( + stageSubmitted: SparkListenerStageSubmitted, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageSubmitted) g.writeFieldName("Stage Info") // SPARK-42205: don't log accumulables in start events: - stageInfoToJson(stageSubmitted.stageInfo, g, includeAccumulables = false) + stageInfoToJson(stageSubmitted.stageInfo, g, options, includeAccumulables = false) Option(stageSubmitted.properties).foreach { properties => g.writeFieldName("Properties") propertiesToJson(properties, g) @@ -125,38 +150,48 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeEndObject() } - def stageCompletedToJson(stageCompleted: SparkListenerStageCompleted, g: JsonGenerator): Unit = { + def stageCompletedToJson( + stageCompleted: SparkListenerStageCompleted, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageCompleted) g.writeFieldName("Stage Info") - stageInfoToJson(stageCompleted.stageInfo, g, includeAccumulables = true) + stageInfoToJson(stageCompleted.stageInfo, g, options, includeAccumulables = true) g.writeEndObject() } - def taskStartToJson(taskStart: SparkListenerTaskStart, g: JsonGenerator): Unit = { + def taskStartToJson( + taskStart: SparkListenerTaskStart, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskStart) g.writeNumberField("Stage ID", taskStart.stageId) g.writeNumberField("Stage Attempt ID", taskStart.stageAttemptId) g.writeFieldName("Task Info") // SPARK-42205: don't log accumulables in start events: - taskInfoToJson(taskStart.taskInfo, g, includeAccumulables = false) + taskInfoToJson(taskStart.taskInfo, g, options, 
includeAccumulables = false) g.writeEndObject() } def taskGettingResultToJson( taskGettingResult: SparkListenerTaskGettingResult, - g: JsonGenerator): Unit = { + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { val taskInfo = taskGettingResult.taskInfo g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskGettingResult) g.writeFieldName("Task Info") // SPARK-42205: don't log accumulables in "task getting result" events: - taskInfoToJson(taskInfo, g, includeAccumulables = false) + taskInfoToJson(taskInfo, g, options, includeAccumulables = false) g.writeEndObject() } - def taskEndToJson(taskEnd: SparkListenerTaskEnd, g: JsonGenerator): Unit = { + def taskEndToJson( + taskEnd: SparkListenerTaskEnd, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskEnd) g.writeNumberField("Stage ID", taskEnd.stageId) @@ -165,7 +200,7 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeFieldName("Task End Reason") taskEndReasonToJson(taskEnd.reason, g) g.writeFieldName("Task Info") - taskInfoToJson(taskEnd.taskInfo, g, includeAccumulables = true) + taskInfoToJson(taskEnd.taskInfo, g, options, includeAccumulables = true) g.writeFieldName("Task Executor Metrics") executorMetricsToJson(taskEnd.taskExecutorMetrics, g) Option(taskEnd.taskMetrics).foreach { m => @@ -175,7 +210,10 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeEndObject() } - def jobStartToJson(jobStart: SparkListenerJobStart, g: JsonGenerator): Unit = { + def jobStartToJson( + jobStart: SparkListenerJobStart, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.jobStart) g.writeNumberField("Job ID", jobStart.jobId) @@ -186,7 +224,7 @@ private[spark] object JsonProtocol extends JsonUtils { // the job was submitted: it is technically possible for a stage to belong to multiple // concurrent jobs, so this situation can arise even without races occurring between // event logging and stage completion. 
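    // A minimal caller-side sketch of the options plumbing (illustrative; assumes sparkConf and
    // event are in scope): build JsonProtocolOptions once from the application's SparkConf, as
    // EventLoggingListener does elsewhere in this patch, and pass it to every serialization call.
    // accumulablesToJson below then drops TaskMetrics-named accumulators unless
    // spark.eventLog.includeTaskMetricsAccumulators is set to true (it defaults to false):
    //   val options = new JsonProtocolOptions(sparkConf)
    //   val json = JsonProtocol.sparkEventToJsonString(event, options)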
- jobStart.stageInfos.foreach(stageInfoToJson(_, g, includeAccumulables = true)) + jobStart.stageInfos.foreach(stageInfoToJson(_, g, options, includeAccumulables = true)) g.writeEndArray() g.writeArrayFieldStart("Stage IDs") jobStart.stageIds.foreach(g.writeNumber) @@ -386,6 +424,7 @@ private[spark] object JsonProtocol extends JsonUtils { def stageInfoToJson( stageInfo: StageInfo, g: JsonGenerator, + options: JsonProtocolOptions, includeAccumulables: Boolean): Unit = { g.writeStartObject() g.writeNumberField("Stage ID", stageInfo.stageId) @@ -404,7 +443,10 @@ private[spark] object JsonProtocol extends JsonUtils { stageInfo.failureReason.foreach(g.writeStringField("Failure Reason", _)) g.writeFieldName("Accumulables") if (includeAccumulables) { - accumulablesToJson(stageInfo.accumulables.values, g) + accumulablesToJson( + stageInfo.accumulables.values, + g, + includeTaskMetricsAccumulators = options.includeTaskMetricsAccumulators) } else { g.writeStartArray() g.writeEndArray() @@ -418,6 +460,7 @@ private[spark] object JsonProtocol extends JsonUtils { def taskInfoToJson( taskInfo: TaskInfo, g: JsonGenerator, + options: JsonProtocolOptions, includeAccumulables: Boolean): Unit = { g.writeStartObject() g.writeNumberField("Task ID", taskInfo.taskId) @@ -435,7 +478,10 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeBooleanField("Killed", taskInfo.killed) g.writeFieldName("Accumulables") if (includeAccumulables) { - accumulablesToJson(taskInfo.accumulables, g) + accumulablesToJson( + taskInfo.accumulables, + g, + includeTaskMetricsAccumulators = options.includeTaskMetricsAccumulators) } else { g.writeStartArray() g.writeEndArray() @@ -443,13 +489,23 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeEndObject() } - private lazy val accumulableExcludeList = Set("internal.metrics.updatedBlockStatuses") + private[util] val accumulableExcludeList = Set(InternalAccumulator.UPDATED_BLOCK_STATUSES) + + private[this] val taskMetricAccumulableNames = TaskMetrics.empty.nameToAccums.keySet.toSet - def accumulablesToJson(accumulables: Iterable[AccumulableInfo], g: JsonGenerator): Unit = { + def accumulablesToJson( + accumulables: Iterable[AccumulableInfo], + g: JsonGenerator, + includeTaskMetricsAccumulators: Boolean = true): Unit = { g.writeStartArray() accumulables - .filterNot(_.name.exists(accumulableExcludeList.contains)) - .toList.sortBy(_.id).foreach(a => accumulableInfoToJson(a, g)) + .filterNot { acc => + acc.name.exists(accumulableExcludeList.contains) || + (!includeTaskMetricsAccumulators && acc.name.exists(taskMetricAccumulableNames.contains)) + } + .toList + .sortBy(_.id) + .foreach(a => accumulableInfoToJson(a, g)) g.writeEndArray() } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index f0d7059e29be1..380231ce97c0b 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -208,7 +208,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft ThreadUtils.awaitReady(job, Duration.Inf).failed.foreach { case e: SparkException => checkError( exception = e, - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map( "jobId" -> "0", @@ -222,7 +222,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft sc.setJobGroup(jobGroupName, "") sc.parallelize(1 
to 100).count() }, - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map( "jobId" -> "1", @@ -258,7 +258,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft ThreadUtils.awaitReady(job, Duration.Inf).failed.foreach { case e: SparkException => checkError( exception = e, - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map( "jobId" -> "0", diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 7106a780b3256..22c6280198c9a 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -27,7 +27,10 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel def sc: SparkContext = _sc - val conf = new SparkConf(false) + // SPARK-49647: use `SparkConf()` instead of `SparkConf(false)` because we want to + // load defaults from system properties and the classpath, including default test + // settings specified in the SBT and Maven build definitions. + val conf: SparkConf = new SparkConf() /** * Initialize the [[SparkContext]]. Generally, this is just called from beforeAll; however, in diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 1966a60c1665e..9f310c06ac5ae 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -327,9 +327,9 @@ abstract class SparkFunSuite } /** - * Checks an exception with an error class against expected results. + * Checks an exception with an error condition against expected results. * @param exception The exception to check - * @param errorClass The expected error class identifying the error + * @param condition The expected error condition identifying the error * @param sqlState Optional the expected SQLSTATE, not verified if not supplied * @param parameters A map of parameter names and values. The names are as defined * in the error-classes file. 
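The rename is mechanical at call sites: the same overloads are kept and only the parameter name changes from errorClass to condition. A minimal sketch mirroring the JobCancellationSuite call sites updated above (the captured exception e is assumed to be a SparkThrowable):

    checkError(
      exception = e,
      condition = "SPARK_JOB_CANCELLED",  // previously errorClass = "SPARK_JOB_CANCELLED"
      sqlState = "XXKDA",
      parameters = Map("jobId" -> "0", "reason" -> ""))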
@@ -338,12 +338,12 @@ abstract class SparkFunSuite */ protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: Option[String] = None, parameters: Map[String, String] = Map.empty, matchPVals: Boolean = false, queryContext: Array[ExpectedContext] = Array.empty): Unit = { - assert(exception.getErrorClass === errorClass) + assert(exception.getErrorClass === condition) sqlState.foreach(state => assert(exception.getSqlState === state)) val expectedParameters = exception.getMessageParameters.asScala if (matchPVals) { @@ -390,55 +390,55 @@ abstract class SparkFunSuite protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: String, parameters: Map[String, String]): Unit = - checkError(exception, errorClass, Some(sqlState), parameters) + checkError(exception, condition, Some(sqlState), parameters) protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: String, parameters: Map[String, String], context: ExpectedContext): Unit = - checkError(exception, errorClass, Some(sqlState), parameters, false, Array(context)) + checkError(exception, condition, Some(sqlState), parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, parameters: Map[String, String], context: ExpectedContext): Unit = - checkError(exception, errorClass, None, parameters, false, Array(context)) + checkError(exception, condition, None, parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: String, context: ExpectedContext): Unit = - checkError(exception, errorClass, None, Map.empty, false, Array(context)) + checkError(exception, condition, Some(sqlState), Map.empty, false, Array(context)) protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: Option[String], parameters: Map[String, String], context: ExpectedContext): Unit = - checkError(exception, errorClass, sqlState, parameters, + checkError(exception, condition, sqlState, parameters, false, Array(context)) protected def checkErrorMatchPVals( exception: SparkThrowable, - errorClass: String, + condition: String, parameters: Map[String, String]): Unit = - checkError(exception, errorClass, None, parameters, matchPVals = true) + checkError(exception, condition, None, parameters, matchPVals = true) protected def checkErrorMatchPVals( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: Option[String], parameters: Map[String, String], context: ExpectedContext): Unit = - checkError(exception, errorClass, sqlState, parameters, + checkError(exception, condition, sqlState, parameters, matchPVals = true, Array(context)) protected def checkErrorTableNotFound( @@ -446,7 +446,7 @@ abstract class SparkFunSuite tableName: String, queryContext: ExpectedContext): Unit = checkError(exception = exception, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> tableName), queryContext = Array(queryContext)) @@ -454,13 +454,13 @@ abstract class SparkFunSuite exception: SparkThrowable, tableName: String): Unit = checkError(exception = exception, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> tableName)) protected def checkErrorTableAlreadyExists(exception: SparkThrowable, tableName: String): Unit 
= checkError(exception = exception, - errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> tableName)) case class ExpectedContext( diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index d99589c171c3f..946ea75686e32 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -43,16 +43,13 @@ class SparkThrowableSuite extends SparkFunSuite { /* Used to regenerate the error class file. Run: {{{ SPARK_GENERATE_GOLDEN_FILES=1 build/sbt \ - "core/testOnly *SparkThrowableSuite -- -t \"Error classes are correctly formatted\"" + "core/testOnly *SparkThrowableSuite -- -t \"Error conditions are correctly formatted\"" }}} */ private val regenerateCommand = "SPARK_GENERATE_GOLDEN_FILES=1 build/sbt " + "\"core/testOnly *SparkThrowableSuite -- -t \\\"Error classes match with document\\\"\"" private val errorJsonFilePath = getWorkspaceFilePath( - // Note that though we call them "error classes" here, the proper name is "error conditions", - // hence why the name of the JSON file is different. We will address this inconsistency as part - // of this ticket: https://issues.apache.org/jira/browse/SPARK-47429 "common", "utils", "src", "main", "resources", "error", "error-conditions.json") private val errorReader = new ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL)) @@ -81,8 +78,8 @@ class SparkThrowableSuite extends SparkFunSuite { mapper.readValue(errorJsonFilePath.toUri.toURL, new TypeReference[Map[String, ErrorInfo]]() {}) } - test("Error classes are correctly formatted") { - val errorClassFileContents = + test("Error conditions are correctly formatted") { + val errorConditionFileContents = IOUtils.toString(errorJsonFilePath.toUri.toURL.openStream(), StandardCharsets.UTF_8) val mapper = JsonMapper.builder() .addModule(DefaultScalaModule) @@ -96,33 +93,30 @@ class SparkThrowableSuite extends SparkFunSuite { .writeValueAsString(errorReader.errorInfoMap) if (regenerateGoldenFiles) { - if (rewrittenString.trim != errorClassFileContents.trim) { - val errorClassesFile = errorJsonFilePath.toFile - logInfo(s"Regenerating error class file $errorClassesFile") - Files.delete(errorClassesFile.toPath) + if (rewrittenString.trim != errorConditionFileContents.trim) { + val errorConditionsFile = errorJsonFilePath.toFile + logInfo(s"Regenerating error conditions file $errorConditionsFile") + Files.delete(errorConditionsFile.toPath) FileUtils.writeStringToFile( - errorClassesFile, + errorConditionsFile, rewrittenString + lineSeparator, StandardCharsets.UTF_8) } } else { - assert(rewrittenString.trim == errorClassFileContents.trim) + assert(rewrittenString.trim == errorConditionFileContents.trim) } } test("SQLSTATE is mandatory") { - val errorClassesNoSqlState = errorReader.errorInfoMap.filter { + val errorConditionsNoSqlState = errorReader.errorInfoMap.filter { case (error: String, info: ErrorInfo) => !error.startsWith("_LEGACY_ERROR_TEMP") && info.sqlState.isEmpty }.keys.toSeq - assert(errorClassesNoSqlState.isEmpty, - s"Error classes without SQLSTATE: ${errorClassesNoSqlState.mkString(", ")}") + assert(errorConditionsNoSqlState.isEmpty, + s"Error classes without SQLSTATE: ${errorConditionsNoSqlState.mkString(", ")}") } test("Error class and error state / SQLSTATE invariants") { - // Unlike in the rest of the codebase, the term "error class" is used here as it is 
in our - // documentation as well as in the SQL standard. We can remove this comment as part of this - // ticket: https://issues.apache.org/jira/browse/SPARK-47429 val errorClassesJson = Utils.getSparkClassLoader.getResource("error/error-classes.json") val errorStatesJson = Utils.getSparkClassLoader.getResource("error/error-states.json") val mapper = JsonMapper.builder() @@ -171,9 +165,9 @@ class SparkThrowableSuite extends SparkFunSuite { .enable(SerializationFeature.INDENT_OUTPUT) .build() mapper.writeValue(tmpFile, errorReader.errorInfoMap) - val rereadErrorClassToInfoMap = mapper.readValue( + val rereadErrorConditionToInfoMap = mapper.readValue( tmpFile, new TypeReference[Map[String, ErrorInfo]]() {}) - assert(rereadErrorClassToInfoMap == errorReader.errorInfoMap) + assert(rereadErrorConditionToInfoMap == errorReader.errorInfoMap) } test("Error class names should contain only capital letters, numbers and underscores") { @@ -207,13 +201,6 @@ class SparkThrowableSuite extends SparkFunSuite { } assert(e.getErrorClass === "INTERNAL_ERROR") assert(e.getMessageParameters().get("message").contains("Undefined error message parameter")) - - // Does not fail with too many args (expects 0 args) - assert(getMessage("DIVIDE_BY_ZERO", Map("config" -> "foo", "a" -> "bar")) == - "[DIVIDE_BY_ZERO] Division by zero. " + - "Use `try_divide` to tolerate divisor being 0 and return NULL instead. " + - "If necessary set foo to \"false\" " + - "to bypass this error. SQLSTATE: 22012") } test("Error message is formatted") { @@ -504,7 +491,7 @@ class SparkThrowableSuite extends SparkFunSuite { |{ | "MISSING_PARAMETER" : { | "message" : [ - | "Parameter ${param} is missing." + | "Parameter is missing." | ] | } |} @@ -517,4 +504,28 @@ class SparkThrowableSuite extends SparkFunSuite { assert(errorMessage.contains("Parameter null is missing.")) } } + + test("detect unused message parameters") { + checkError( + exception = intercept[SparkException] { + SparkThrowableHelper.getMessage( + errorClass = "CANNOT_UP_CAST_DATATYPE", + messageParameters = Map( + "expression" -> "CAST('aaa' AS LONG)", + "sourceType" -> "STRING", + "targetType" -> "LONG", + "op" -> "CAST", // unused parameter + "details" -> "implicit cast" + )) + }, + condition = "INTERNAL_ERROR", + parameters = Map( + "message" -> + ("Found unused message parameters of the error class 'CANNOT_UP_CAST_DATATYPE'. " + + "Its error message format has 4 placeholders, but the passed message parameters map " + + "has 5 items. Consider to add placeholders to the error format or " + + "remove unused message parameters.") + ) + ) + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 40d8eae644a07..ca81283e073ac 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -1802,6 +1802,23 @@ class SparkSubmitSuite val (_, classpath, _, _) = submit.prepareSubmitEnvironment(appArgs) assert(classpath.contains(".")) } + + // Requires Python dependencies for Spark Connect. Should be enabled by default. 
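  // Roughly equivalent manual invocation of the ignored test below (illustrative; the script
  // path is a placeholder):
  //   bin/spark-submit --name testPyApp --remote local /path/to/remote_test.py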
+ ignore("Spark Connect application submission (Python)") { + val pyFile = File.createTempFile("remote_test", ".py") + pyFile.deleteOnExit() + val content = + "from pyspark.sql import SparkSession;" + + "spark = SparkSession.builder.getOrCreate();" + + "assert 'connect' in str(type(spark));" + + "assert spark.range(1).first()[0] == 0" + FileUtils.write(pyFile, content, StandardCharsets.UTF_8) + val args = Seq( + "--name", "testPyApp", + "--remote", "local", + pyFile.getAbsolutePath) + runSparkSubmit(args) + } } object JarCreationTest extends Logging { diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 5c09a1f965b9e..ff971b72d8910 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -132,7 +132,7 @@ class CompressionCodecSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { CompressionCodec.createCodec(conf, "foobar") }, - errorClass = "CODEC_NOT_AVAILABLE.WITH_CONF_SUGGESTION", + condition = "CODEC_NOT_AVAILABLE.WITH_CONF_SUGGESTION", parameters = Map( "codecName" -> "foobar", "configKey" -> "\"spark.io.compression.codec\"", @@ -171,7 +171,7 @@ class CompressionCodecSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { CompressionCodec.getShortName(codecClass.toUpperCase(Locale.ROOT)) }, - errorClass = "CODEC_SHORT_NAME_NOT_FOUND", + condition = "CODEC_SHORT_NAME_NOT_FOUND", parameters = Map("codecName" -> codecClass.toUpperCase(Locale.ROOT))) } } diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala index 55d82aed5c3f2..817d660763361 100644 --- a/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala @@ -88,7 +88,7 @@ class GraphiteSinkSuite extends SparkFunSuite { val e = intercept[SparkException] { new GraphiteSink(props, registry) } - checkError(e, errorClass = "GRAPHITE_SINK_PROPERTY_MISSING", + checkError(e, condition = "GRAPHITE_SINK_PROPERTY_MISSING", parameters = Map("property" -> "host")) } @@ -100,7 +100,7 @@ class GraphiteSinkSuite extends SparkFunSuite { val e = intercept[SparkException] { new GraphiteSink(props, registry) } - checkError(e, errorClass = "GRAPHITE_SINK_PROPERTY_MISSING", + checkError(e, condition = "GRAPHITE_SINK_PROPERTY_MISSING", parameters = Map("property" -> "port")) } @@ -115,7 +115,7 @@ class GraphiteSinkSuite extends SparkFunSuite { exception = intercept[SparkException] { new GraphiteSink(props, registry) }, - errorClass = "GRAPHITE_SINK_INVALID_PROTOCOL", + condition = "GRAPHITE_SINK_INVALID_PROTOCOL", parameters = Map("protocol" -> "http") ) } diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushReporterSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushReporterSuite.scala deleted file mode 100644 index 3f9c75062f78f..0000000000000 --- a/core/src/test/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushReporterSuite.scala +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.metrics.sink.opentelemetry - -import com.codahale.metrics._ -import org.junit.jupiter.api.Assertions.assertNotNull -import org.scalatest.PrivateMethodTester - -import org.apache.spark.SparkFunSuite - -class OpenTelemetryPushReporterSuite - extends SparkFunSuite with PrivateMethodTester { - val reporter = new OpenTelemetryPushReporter( - registry = new MetricRegistry(), - trustedCertificatesPath = null, - privateKeyPemPath = null, - certificatePemPath = null - ) - - test("Normalize metric name key") { - val name = "local-1592132938718.driver.LiveListenerBus." + - "listenerProcessingTime.org.apache.spark.HeartbeatReceiver" - val metricsName = reporter invokePrivate PrivateMethod[String]( - Symbol("normalizeMetricName") - )(name) - assert( - metricsName == "local_1592132938718_driver_livelistenerbus_" + - "listenerprocessingtime_org_apache_spark_heartbeatreceiver" - ) - } - - test("OpenTelemetry actions when one codahale gauge is added") { - val gauge = new Gauge[Double] { - override def getValue: Double = 1.23 - } - reporter.onGaugeAdded("test-gauge", gauge) - assertNotNull(reporter.openTelemetryGauges("test_gauge")) - } - - test("OpenTelemetry actions when one codahale counter is added") { - val counter = new Counter - reporter.onCounterAdded("test_counter", counter) - assertNotNull(reporter.openTelemetryCounters("test_counter")) - } - - test("OpenTelemetry actions when one codahale histogram is added") { - val histogram = new Histogram(new UniformReservoir) - reporter.onHistogramAdded("test_hist", histogram) - assertNotNull(reporter.openTelemetryHistograms("test_hist_count")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_max")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_min")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_mean")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_std_dev")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_50_percentile")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_75_percentile")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_95_percentile")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_98_percentile")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_99_percentile")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_999_percentile")) - } - - test("OpenTelemetry actions when one codahale meter is added") { - val meter = new Meter() - reporter.onMeterAdded("test_meter", meter) - assertNotNull(reporter.openTelemetryGauges("test_meter_meter_count")) - assertNotNull(reporter.openTelemetryGauges("test_meter_meter_mean_rate")) - assertNotNull( - reporter.openTelemetryGauges("test_meter_meter_one_minute_rate") - ) - assertNotNull( - reporter.openTelemetryGauges("test_meter_meter_five_minute_rate") - ) - assertNotNull( - reporter.openTelemetryGauges("test_meter_meter_fifteen_minute_rate") - ) - } - - test("OpenTelemetry 
actions when one codahale timer is added") { - val timer = new Timer() - reporter.onTimerAdded("test_timer", timer) - assertNotNull(reporter.openTelemetryHistograms("test_timer_timer_count")) - assertNotNull(reporter.openTelemetryHistograms("test_timer_timer_max")) - assertNotNull(reporter.openTelemetryHistograms("test_timer_timer_min")) - assertNotNull(reporter.openTelemetryHistograms("test_timer_timer_mean")) - assertNotNull(reporter.openTelemetryHistograms("test_timer_timer_std_dev")) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_50_percentile") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_75_percentile") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_95_percentile") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_95_percentile") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_99_percentile") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_999_percentile") - ) - - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_fifteen_minute_rate") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_five_minute_rate") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_one_minute_rate") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_mean_rate") - ) - } -} diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushSinkSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushSinkSuite.scala deleted file mode 100644 index 25aca82a22c40..0000000000000 --- a/core/src/test/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushSinkSuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.metrics.sink.opentelemetry - -import java.util.Properties - -import com.codahale.metrics._ -import org.junit.jupiter.api.Assertions.assertEquals -import org.scalatest.PrivateMethodTester - -import org.apache.spark.SparkFunSuite - -class OpenTelemetryPushSinkSuite - extends SparkFunSuite with PrivateMethodTester { - test("fetch properties map") { - val properties = new Properties - properties.put("foo1.foo2.foo3.foo4.header.key1.key2.key3", "value1") - properties.put("foo1.foo2.foo3.foo4.header.key2", "value2") - val keyPrefix = "foo1.foo2.foo3.foo4.header" - val propertiesMap: Map[String, String] = OpenTelemetryPushSink invokePrivate - PrivateMethod[Map[String, String]](Symbol("fetchMapFromProperties"))( - properties, - keyPrefix - ) - - assert("value1".equals(propertiesMap("key1.key2.key3"))) - assert("value2".equals(propertiesMap("key2"))) - } - - test("OpenTelemetry sink with one counter added") { - val props = new Properties - props.put("endpoint", "http://127.0.0.1:10086") - val registry = new MetricRegistry - val sink = new OpenTelemetryPushSink(props, registry) - sink.start() - val reporter = sink.reporter - val counter = registry.counter("test-counter") - assertEquals(reporter.openTelemetryCounters.size, 1) - } -} diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 7c5db914cd5ba..8bb96a0f53c73 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -922,7 +922,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { exception = intercept[SparkIllegalArgumentException] { rdd1.cartesian(rdd2).partitions }, - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE", + condition = "COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE", sqlState = "54000", parameters = Map( "numberOfElements" -> (numSlices.toLong * numSlices.toLong).toString, diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 978ceb16b376c..243d33fe55a79 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -779,7 +779,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assert(failureReason.isDefined) checkError( exception = failureReason.get.asInstanceOf[SparkException], - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map("jobId" -> "0", "reason" -> "") ) @@ -901,7 +901,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti cancel(jobId) checkError( exception = failure.asInstanceOf[SparkException], - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map("jobId" -> jobId.toString, "reason" -> "") ) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala index 5b6fb31d598ac..aad649b7b2612 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala @@ -111,7 +111,7 @@ class ShuffleExternalSorterSuite extends SparkFunSuite 
with LocalSparkContext wi exception = intercept[SparkOutOfMemoryError] { sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0) }, - errorClass = "UNABLE_TO_ACQUIRE_MEMORY", + condition = "UNABLE_TO_ACQUIRE_MEMORY", parameters = Map("requestedBytes" -> "800", "receivedBytes" -> "400")) } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index cdee6ccda706e..30c9693e6dee3 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -32,6 +32,7 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark._ import org.apache.spark.executor._ +import org.apache.spark.internal.config._ import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.{DeterministicLevel, RDDOperationScope} import org.apache.spark.resource._ @@ -276,7 +277,8 @@ class JsonProtocolSuite extends SparkFunSuite { test("StageInfo backward compatibility (details, accumulables)") { val info = makeStageInfo(1, 2, 3, 4L, 5L) - val newJson = toJsonString(JsonProtocol.stageInfoToJson(info, _, includeAccumulables = true)) + val newJson = toJsonString( + JsonProtocol.stageInfoToJson(info, _, defaultOptions, includeAccumulables = true)) // Fields added after 1.0.0. assert(info.details.nonEmpty) @@ -294,7 +296,8 @@ class JsonProtocolSuite extends SparkFunSuite { test("StageInfo resourceProfileId") { val info = makeStageInfo(1, 2, 3, 4L, 5L, 5) - val json = toJsonString(JsonProtocol.stageInfoToJson(info, _, includeAccumulables = true)) + val json = toJsonString( + JsonProtocol.stageInfoToJson(info, _, defaultOptions, includeAccumulables = true)) // Fields added after 1.0.0. 
assert(info.details.nonEmpty) @@ -471,7 +474,7 @@ class JsonProtocolSuite extends SparkFunSuite { stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown", resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)) val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) - val oldEvent = toJsonString(JsonProtocol.jobStartToJson(jobStart, _)).removeField("Stage Infos") + val oldEvent = sparkEventToJsonString(jobStart).removeField("Stage Infos") val expectedJobStart = SparkListenerJobStart(10, jobSubmissionTime, dummyStageInfos, properties) assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldEvent)) @@ -483,8 +486,7 @@ class JsonProtocolSuite extends SparkFunSuite { val stageIds = Seq[Int](1, 2, 3, 4) val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40L, x * 50L)) val jobStart = SparkListenerJobStart(11, jobSubmissionTime, stageInfos, properties) - val oldStartEvent = toJsonString(JsonProtocol.jobStartToJson(jobStart, _)) - .removeField("Submission Time") + val oldStartEvent = sparkEventToJsonString(jobStart).removeField("Submission Time") val expectedJobStart = SparkListenerJobStart(11, -1, stageInfos, properties) assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldStartEvent)) @@ -519,8 +521,9 @@ class JsonProtocolSuite extends SparkFunSuite { val stageInfo = new StageInfo(1, 1, "me-stage", 1, Seq.empty, Seq(1, 2, 3), "details", resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) val oldStageInfo = - toJsonString(JsonProtocol.stageInfoToJson(stageInfo, _, includeAccumulables = true)) - .removeField("Parent IDs") + toJsonString( + JsonProtocol.stageInfoToJson(stageInfo, _, defaultOptions, includeAccumulables = true) + ).removeField("Parent IDs") val expectedStageInfo = new StageInfo(1, 1, "me-stage", 1, Seq.empty, Seq.empty, "details", resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) assertEquals(expectedStageInfo, JsonProtocol.stageInfoFromJson(oldStageInfo)) @@ -785,6 +788,87 @@ class JsonProtocolSuite extends SparkFunSuite { assert(JsonProtocol.sparkEventFromJson(unknownFieldsJson) === expected) } + test("SPARK-42204: spark.eventLog.includeTaskMetricsAccumulators config") { + val includeConf = new JsonProtocolOptions( + new SparkConf().set(EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS, true)) + val excludeConf = new JsonProtocolOptions( + new SparkConf().set(EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS, false)) + + val taskMetricsAccumulables = TaskMetrics + .empty + .nameToAccums + .view + .filterKeys(!JsonProtocol.accumulableExcludeList.contains(_)) + .values + .map(_.toInfo(Some(1), None)) + .toSeq + + val taskInfoWithTaskMetricsAccums = makeTaskInfo(222L, 333, 1, 333, 444L, false) + taskInfoWithTaskMetricsAccums.setAccumulables(taskMetricsAccumulables) + val taskInfoWithoutTaskMetricsAccums = makeTaskInfo(222L, 333, 1, 333, 444L, false) + taskInfoWithoutTaskMetricsAccums.setAccumulables(Seq.empty) + + val stageInfoWithTaskMetricsAccums = makeStageInfo(100, 200, 300, 400L, 500L) + stageInfoWithTaskMetricsAccums.accumulables.clear() + stageInfoWithTaskMetricsAccums.accumulables ++= taskMetricsAccumulables.map(x => (x.id, x)) + val stageInfoWithoutTaskMetricsAccums = makeStageInfo(100, 200, 300, 400L, 500L) + stageInfoWithoutTaskMetricsAccums.accumulables.clear() + + // Test events which should be impacted by the config. 
+ + // TaskEnd + { + val originalEvent = SparkListenerTaskEnd(1, 0, "ShuffleMapTask", Success, + taskInfoWithTaskMetricsAccums, + new ExecutorMetrics(Array(12L, 23L, 45L, 67L, 78L, 89L, + 90L, 123L, 456L, 789L, 40L, 20L, 20L, 10L, 20L, 10L, 301L)), + makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, 0, + hasHadoopInput = false, hasOutput = false)) + assertEquals( + originalEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, includeConf))) + val trimmedEvent = originalEvent.copy(taskInfo = taskInfoWithoutTaskMetricsAccums) + assertEquals( + trimmedEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, excludeConf))) + } + + // StageCompleted + { + val originalEvent = SparkListenerStageCompleted(stageInfoWithTaskMetricsAccums) + assertEquals( + originalEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, includeConf))) + val trimmedEvent = originalEvent.copy(stageInfo = stageInfoWithoutTaskMetricsAccums) + assertEquals( + trimmedEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, excludeConf))) + } + + // JobStart + { + val originalEvent = + SparkListenerJobStart(1, 1, Seq(stageInfoWithTaskMetricsAccums), properties) + assertEquals( + originalEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, includeConf))) + val trimmedEvent = originalEvent.copy(stageInfos = Seq(stageInfoWithoutTaskMetricsAccums)) + assertEquals( + trimmedEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, excludeConf))) + } + + // ExecutorMetricsUpdate events should be unaffected by the config: + val executorMetricsUpdate = + SparkListenerExecutorMetricsUpdate("0", Seq((0, 0, 0, taskMetricsAccumulables))) + assert( + sparkEventToJsonString(executorMetricsUpdate, includeConf) === + sparkEventToJsonString(executorMetricsUpdate, excludeConf)) + assertEquals( + JsonProtocol.sparkEventFromJson(sparkEventToJsonString(executorMetricsUpdate, includeConf)), + executorMetricsUpdate) + } + test("SPARK-42403: properly handle null string values") { // Null string values can appear in a few different event types, // so we test multiple known cases here: @@ -966,7 +1050,8 @@ private[spark] object JsonProtocolSuite extends Assertions { private def testStageInfo(info: StageInfo): Unit = { val newInfo = JsonProtocol.stageInfoFromJson( - toJsonString(JsonProtocol.stageInfoToJson(info, _, includeAccumulables = true))) + toJsonString( + JsonProtocol.stageInfoToJson(info, _, defaultOptions, includeAccumulables = true))) assertEquals(info, newInfo) } @@ -990,7 +1075,8 @@ private[spark] object JsonProtocolSuite extends Assertions { private def testTaskInfo(info: TaskInfo): Unit = { val newInfo = JsonProtocol.taskInfoFromJson( - toJsonString(JsonProtocol.taskInfoToJson(info, _, includeAccumulables = true))) + toJsonString( + JsonProtocol.taskInfoToJson(info, _, defaultOptions, includeAccumulables = true))) assertEquals(info, newInfo) } diff --git a/dev/.rat-excludes b/dev/.rat-excludes index f38fd7e2012a5..b82cb7078c9f3 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -140,3 +140,4 @@ ui-test/package.json ui-test/package-lock.json core/src/main/resources/org/apache/spark/ui/static/package.json .*\.har +.nojekyll diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile index e86b91968bf80..3cba72d042ed6 100644 --- a/dev/create-release/spark-rm/Dockerfile +++ b/dev/create-release/spark-rm/Dockerfile @@ -94,7 +94,7 @@ ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library RUN add-apt-repository 
ppa:pypy/ppa RUN mkdir -p /usr/local/pypy/pypy3.9 && \ curl -sqL https://downloads.python.org/pypy/pypy3.9-v7.3.16-linux64.tar.bz2 | tar xjf - -C /usr/local/pypy/pypy3.9 --strip-components=1 && \ - ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.8 && \ + ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.9 && \ ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3 RUN curl -sS https://bootstrap.pypa.io/get-pip.py | pypy3 RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.2.2' scipy coverage matplotlib lxml @@ -137,6 +137,7 @@ RUN python3.9 -m pip list RUN gem install --no-document "bundler:2.4.22" RUN ln -s "$(which python3.9)" "/usr/local/bin/python" +RUN ln -s "$(which python3.9)" "/usr/local/bin/python3" WORKDIR /opt/spark-rm/output diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 8edb8b36cacb7..419625f48fa11 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -4,7 +4,7 @@ JTransforms/3.1//JTransforms-3.1.jar RoaringBitmap/1.2.1//RoaringBitmap-1.2.1.jar ST4/4.0.4//ST4-4.0.4.jar activation/1.1.1//activation-1.1.1.jar -aircompressor/0.27//aircompressor-0.27.jar +aircompressor/2.0.2//aircompressor-2.0.2.jar algebra_2.13/2.8.0//algebra_2.13-2.8.0.jar aliyun-java-sdk-core/4.5.10//aliyun-java-sdk-core-4.5.10.jar aliyun-java-sdk-kms/2.11.0//aliyun-java-sdk-kms-2.11.0.jar @@ -33,6 +33,7 @@ breeze-macros_2.13/2.1.0//breeze-macros_2.13-2.1.0.jar breeze_2.13/2.1.0//breeze_2.13-2.1.0.jar bundle/2.24.6//bundle-2.24.6.jar cats-kernel_2.13/2.8.0//cats-kernel_2.13-2.8.0.jar +checker-qual/3.42.0//checker-qual-3.42.0.jar chill-java/0.10.0//chill-java-0.10.0.jar chill_2.13/0.10.0//chill_2.13-0.10.0.jar commons-cli/1.9.0//commons-cli-1.9.0.jar @@ -43,7 +44,7 @@ commons-compiler/3.1.9//commons-compiler-3.1.9.jar commons-compress/1.27.1//commons-compress-1.27.1.jar commons-crypto/1.1.0//commons-crypto-1.1.0.jar commons-dbcp/1.4//commons-dbcp-1.4.jar -commons-io/2.16.1//commons-io-2.16.1.jar +commons-io/2.17.0//commons-io-2.17.0.jar commons-lang/2.6//commons-lang-2.6.jar commons-lang3/3.17.0//commons-lang3-3.17.0.jar commons-math3/3.6.1//commons-math3-3.6.1.jar @@ -62,12 +63,14 @@ derby/10.16.1.1//derby-10.16.1.1.jar derbyshared/10.16.1.1//derbyshared-10.16.1.1.jar derbytools/10.16.1.1//derbytools-10.16.1.1.jar dropwizard-metrics-hadoop-metrics2-reporter/0.1.2//dropwizard-metrics-hadoop-metrics2-reporter-0.1.2.jar +error_prone_annotations/2.26.1//error_prone_annotations-2.26.1.jar esdk-obs-java/3.20.4.2//esdk-obs-java-3.20.4.2.jar +failureaccess/1.0.2//failureaccess-1.0.2.jar flatbuffers-java/24.3.25//flatbuffers-java-24.3.25.jar gcs-connector/hadoop3-2.2.21/shaded/gcs-connector-hadoop3-2.2.21-shaded.jar gmetric4j/1.0.10//gmetric4j-1.0.10.jar gson/2.11.0//gson-2.11.0.jar -guava/14.0.1//guava-14.0.1.jar +guava/33.2.1-jre//guava-33.2.1-jre.jar hadoop-aliyun/3.4.0//hadoop-aliyun-3.4.0.jar hadoop-annotations/3.4.0//hadoop-annotations-3.4.0.jar hadoop-aws/3.4.0//hadoop-aws-3.4.0.jar @@ -101,6 +104,7 @@ icu4j/75.1//icu4j-75.1.jar ini4j/0.5.4//ini4j-0.5.4.jar istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar ivy/2.5.2//ivy-2.5.2.jar +j2objc-annotations/3.0.0//j2objc-annotations-3.0.0.jar jackson-annotations/2.17.2//jackson-annotations-2.17.2.jar jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar jackson-core/2.17.2//jackson-core-2.17.2.jar @@ -142,7 +146,7 @@ jjwt-api/0.12.6//jjwt-api-0.12.6.jar jline/2.14.6//jline-2.14.6.jar jline/3.25.1//jline-3.25.1.jar 
jna/5.14.0//jna-5.14.0.jar -joda-time/2.12.7//joda-time-2.12.7.jar +joda-time/2.13.0//joda-time-2.13.0.jar jodd-core/3.5.2//jodd-core-3.5.2.jar jpam/1.1//jpam-1.1.jar json/1.8//json-1.8.jar @@ -184,6 +188,7 @@ lapack/3.0.3//lapack-3.0.3.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar libthrift/0.16.0//libthrift-0.16.0.jar +listenablefuture/9999.0-empty-to-avoid-conflict-with-guava//listenablefuture-9999.0-empty-to-avoid-conflict-with-guava.jar log4j-1.2-api/2.22.1//log4j-1.2-api-2.22.1.jar log4j-api/2.22.1//log4j-api-2.22.1.jar log4j-core/2.22.1//log4j-core-2.22.1.jar @@ -207,12 +212,12 @@ netty-common/4.1.110.Final//netty-common-4.1.110.Final.jar netty-handler-proxy/4.1.110.Final//netty-handler-proxy-4.1.110.Final.jar netty-handler/4.1.110.Final//netty-handler-4.1.110.Final.jar netty-resolver/4.1.110.Final//netty-resolver-4.1.110.Final.jar -netty-tcnative-boringssl-static/2.0.65.Final/linux-aarch_64/netty-tcnative-boringssl-static-2.0.65.Final-linux-aarch_64.jar -netty-tcnative-boringssl-static/2.0.65.Final/linux-x86_64/netty-tcnative-boringssl-static-2.0.65.Final-linux-x86_64.jar -netty-tcnative-boringssl-static/2.0.65.Final/osx-aarch_64/netty-tcnative-boringssl-static-2.0.65.Final-osx-aarch_64.jar -netty-tcnative-boringssl-static/2.0.65.Final/osx-x86_64/netty-tcnative-boringssl-static-2.0.65.Final-osx-x86_64.jar -netty-tcnative-boringssl-static/2.0.65.Final/windows-x86_64/netty-tcnative-boringssl-static-2.0.65.Final-windows-x86_64.jar -netty-tcnative-classes/2.0.65.Final//netty-tcnative-classes-2.0.65.Final.jar +netty-tcnative-boringssl-static/2.0.66.Final/linux-aarch_64/netty-tcnative-boringssl-static-2.0.66.Final-linux-aarch_64.jar +netty-tcnative-boringssl-static/2.0.66.Final/linux-x86_64/netty-tcnative-boringssl-static-2.0.66.Final-linux-x86_64.jar +netty-tcnative-boringssl-static/2.0.66.Final/osx-aarch_64/netty-tcnative-boringssl-static-2.0.66.Final-osx-aarch_64.jar +netty-tcnative-boringssl-static/2.0.66.Final/osx-x86_64/netty-tcnative-boringssl-static-2.0.66.Final-osx-x86_64.jar +netty-tcnative-boringssl-static/2.0.66.Final/windows-x86_64/netty-tcnative-boringssl-static-2.0.66.Final-windows-x86_64.jar +netty-tcnative-classes/2.0.66.Final//netty-tcnative-classes-2.0.66.Final.jar netty-transport-classes-epoll/4.1.110.Final//netty-transport-classes-epoll-4.1.110.Final.jar netty-transport-classes-kqueue/4.1.110.Final//netty-transport-classes-kqueue-4.1.110.Final.jar netty-transport-native-epoll/4.1.110.Final/linux-aarch_64/netty-transport-native-epoll-4.1.110.Final-linux-aarch_64.jar @@ -236,12 +241,12 @@ orc-shims/2.0.2//orc-shims-2.0.2.jar oro/2.0.8//oro-2.0.8.jar osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar paranamer/2.8//paranamer-2.8.jar -parquet-column/1.14.1//parquet-column-1.14.1.jar -parquet-common/1.14.1//parquet-common-1.14.1.jar -parquet-encoding/1.14.1//parquet-encoding-1.14.1.jar -parquet-format-structures/1.14.1//parquet-format-structures-1.14.1.jar -parquet-hadoop/1.14.1//parquet-hadoop-1.14.1.jar -parquet-jackson/1.14.1//parquet-jackson-1.14.1.jar +parquet-column/1.14.2//parquet-column-1.14.2.jar +parquet-common/1.14.2//parquet-common-1.14.2.jar +parquet-encoding/1.14.2//parquet-encoding-1.14.2.jar +parquet-format-structures/1.14.2//parquet-format-structures-1.14.2.jar +parquet-hadoop/1.14.2//parquet-hadoop-1.14.2.jar +parquet-jackson/1.14.2//parquet-jackson-1.14.2.jar pickle/1.5//pickle-1.5.jar py4j/0.10.9.7//py4j-0.10.9.7.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar @@ -256,7 +261,7 @@ 
scala-xml_2.13/2.3.0//scala-xml_2.13-2.3.0.jar slf4j-api/2.0.16//slf4j-api-2.0.16.jar snakeyaml-engine/2.7//snakeyaml-engine-2.7.jar snakeyaml/2.2//snakeyaml-2.2.jar -snappy-java/1.1.10.6//snappy-java-1.1.10.6.jar +snappy-java/1.1.10.7//snappy-java-1.1.10.7.jar spire-macros_2.13/0.18.0//spire-macros_2.13-0.18.0.jar spire-platform_2.13/0.18.0//spire-platform_2.13-0.18.0.jar spire-util_2.13/0.18.0//spire-util_2.13-0.18.0.jar @@ -265,7 +270,7 @@ stax-api/1.0.1//stax-api-1.0.1.jar stream/2.9.8//stream-2.9.8.jar super-csv/2.2.0//super-csv-2.2.0.jar threeten-extra/1.7.1//threeten-extra-1.7.1.jar -tink/1.14.1//tink-1.14.1.jar +tink/1.15.0//tink-1.15.0.jar transaction-api/1.1//transaction-api-1.1.jar univocity-parsers/2.9.1//univocity-parsers-2.9.1.jar wildfly-openssl/1.1.3.Final//wildfly-openssl-1.1.3.Final.jar diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index b01e3c50e28d3..5939e429b2f35 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -24,7 +24,7 @@ LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image" # Overwrite this label to avoid exposing the underlying Ubuntu OS version label LABEL org.opencontainers.image.version="" -ENV FULL_REFRESH_DATE 20240318 +ENV FULL_REFRESH_DATE 20240903 ENV DEBIAN_FRONTEND noninteractive ENV DEBCONF_NONINTERACTIVE_SEEN true @@ -88,13 +88,13 @@ ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library RUN add-apt-repository ppa:pypy/ppa RUN mkdir -p /usr/local/pypy/pypy3.9 && \ curl -sqL https://downloads.python.org/pypy/pypy3.9-v7.3.16-linux64.tar.bz2 | tar xjf - -C /usr/local/pypy/pypy3.9 --strip-components=1 && \ - ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.8 && \ + ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.9 && \ ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3 RUN curl -sS https://bootstrap.pypa.io/get-pip.py | pypy3 -RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.2.2' scipy coverage matplotlib lxml +RUN pypy3 -m pip install 'numpy==1.26.4' 'six==1.16.0' 'pandas==2.2.2' scipy coverage matplotlib lxml -ARG BASIC_PIP_PKGS="numpy pyarrow>=15.0.0 six==1.16.0 pandas==2.2.2 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" +ARG BASIC_PIP_PKGS="numpy==1.26.4 pyarrow>=15.0.0 six==1.16.0 pandas==2.2.2 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" # Python deps for Spark Connect ARG CONNECT_PIP_PKGS="grpcio==1.62.0 grpcio-status==1.62.0 protobuf==4.25.1 googleapis-common-protos==1.56.4 graphviz==0.20.3" diff --git a/dev/lint-scala b/dev/lint-scala index 98b850da68838..23df146a8d1b4 100755 --- a/dev/lint-scala +++ b/dev/lint-scala @@ -29,6 +29,7 @@ ERRORS=$(./build/mvn \ -Dscalafmt.skip=false \ -Dscalafmt.validateOnly=true \ -Dscalafmt.changedOnly=false \ + -pl sql/api \ -pl sql/connect/common \ -pl sql/connect/server \ -pl connector/connect/client/jvm \ @@ -38,7 +39,7 @@ ERRORS=$(./build/mvn \ if test ! 
-z "$ERRORS"; then echo -e "The scalafmt check failed on sql/connect or connector/connect at following occurrences:\n\n$ERRORS\n" echo "Before submitting your change, please make sure to format your code using the following command:" - echo "./build/mvn scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=false -Dscalafmt.changedOnly=false -pl sql/connect/common -pl sql/connect/server -pl connector/connect/client/jvm" + echo "./build/mvn scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=false -Dscalafmt.changedOnly=false -pl sql/api -pl sql/connect/common -pl sql/connect/server -pl connector/connect/client/jvm" exit 1 else echo -e "Scalafmt checks passed." diff --git a/dev/py-cleanup b/dev/py-cleanup new file mode 100755 index 0000000000000..6a2edd1040171 --- /dev/null +++ b/dev/py-cleanup @@ -0,0 +1,31 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility for temporary files cleanup in 'python'. +# usage: ./dev/py-cleanup + +set -ex + +SPARK_HOME="$(cd "`dirname $0`"/..; pwd)" +cd "$SPARK_HOME" + +rm -rf python/target +rm -rf python/lib/pyspark.zip +rm -rf python/docs/build +rm -rf python/docs/source/reference/*/api diff --git a/dev/requirements.txt b/dev/requirements.txt index e0216a63ba790..cafc73405aaa8 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -3,11 +3,11 @@ py4j>=0.10.9.7 # PySpark dependencies (optional) numpy>=1.21 -pyarrow>=4.0.0 +pyarrow>=10.0.0 six==1.16.0 -pandas>=1.4.4 +pandas>=2.0.0 scipy -plotly +plotly>=4.8 mlflow>=2.3.1 scikit-learn matplotlib diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 181cd28cda78d..b9a4bed715f67 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -206,7 +206,6 @@ def __hash__(self): sbt_test_goals=[ "core/test", ], - build_profile_flags=["-Popentelemetry"], ) api = Module( @@ -549,6 +548,8 @@ def __hash__(self): "pyspark.sql.tests.test_udtf", "pyspark.sql.tests.test_utils", "pyspark.sql.tests.test_resources", + "pyspark.sql.tests.plot.test_frame_plot", + "pyspark.sql.tests.plot.test_frame_plot_plotly", ], ) @@ -1052,6 +1053,8 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map", "pyspark.sql.tests.connect.test_parity_python_datasource", "pyspark.sql.tests.connect.test_parity_python_streaming_datasource", + "pyspark.sql.tests.connect.test_parity_frame_plot", + "pyspark.sql.tests.connect.test_parity_frame_plot_plotly", "pyspark.sql.tests.connect.test_utils", "pyspark.sql.tests.connect.client.test_artifact", "pyspark.sql.tests.connect.client.test_artifact_localcluster", diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 502113d11b77e..a85fd16451469 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ 
-58,7 +58,7 @@
\ No newline at end of file
+
+## Data Source Option
+
+Data source options of Protobuf can be set via:
+* the built-in functions below
+  * `from_protobuf`
+  * `to_protobuf`
+
+<table class="table table-striped">
+  <thead><tr><th><b>Property Name</b></th><th><b>Default</b></th><th><b>Meaning</b></th><th><b>Scope</b></th></tr></thead>
+  <tr>
+    <td><code>mode</code></td>
+    <td><code>FAILFAST</code></td>
+    <td>Allows a mode for dealing with corrupt records during parsing.<br>
+      <ul>
+        <li><code>PERMISSIVE</code>: when it meets a corrupted record, sets all fields to null.</li>
+        <li><code>DROPMALFORMED</code>: ignores the whole corrupted record. This mode is unsupported in the Protobuf built-in functions.</li>
+        <li><code>FAILFAST</code>: throws an exception when it meets corrupted records.</li>
+      </ul>
+    </td>
+    <td>read</td>
+  </tr>
+  <tr><td><code>recursive.fields.max.depth</code></td><td>-1</td><td>Specifies the maximum number of recursion levels to allow when parsing the schema. For more details, refer to the section on handling circular references in protobuf fields.</td><td>read</td></tr>
+  <tr><td><code>convert.any.fields.to.json</code></td><td>false</td><td>Enables converting Protobuf <code>Any</code> fields to JSON. This option should be enabled carefully: JSON conversion and processing are inefficient, and schema safety is also reduced, making downstream processing error-prone.</td><td>read</td></tr>
+  <tr><td><code>emit.default.values</code></td><td>false</td><td>Whether to render fields with zero values when deserializing Protobuf to a Spark struct. When a field is empty in the serialized Protobuf, it is deserialized as null by default; this option controls whether to render the type-specific zero values instead.</td><td>read</td></tr>
+  <tr><td><code>enums.as.ints</code></td><td>false</td><td>Whether to render enum fields as their integer values. When this option is set to false, an enum field is mapped to StringType and the value is the name of the enum; when set to true, an enum field is mapped to IntegerType and the value is its integer value.</td><td>read</td></tr>
+  <tr><td><code>upcast.unsigned.ints</code></td><td>false</td><td>Whether to upcast unsigned integers into a larger type. When this option is set to true, LongType is used for uint32 and Decimal(20, 0) is used for uint64, so their representation can contain large unsigned values without overflow.</td><td>read</td></tr>
+  <tr><td><code>unwrap.primitive.wrapper.types</code></td><td>false</td><td>Whether to unwrap the struct representation for well-known primitive wrapper types when deserializing. By default, the wrapper types for primitives (i.e. google.protobuf.Int32Value, google.protobuf.Int64Value, etc.) are deserialized as structs.</td><td>read</td></tr>
+  <tr><td><code>retain.empty.message.types</code></td><td>false</td><td>Whether to retain fields of the empty proto message type in the schema. Since Spark doesn't allow writing an empty StructType, the empty proto message type is dropped by default. Setting this option to true inserts a dummy column (<code>__dummy_field_in_empty_struct</code>) into the empty proto message so that the empty message fields are retained.</td><td>read</td></tr>
+</table>
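As an illustrative sketch (not part of this patch): the options above are plain string key/value pairs passed to the built-in functions, for example to `from_protobuf` in PySpark. The DataFrame `df`, its binary column `proto_bytes`, the message name `Person`, and the descriptor file path below are all assumptions made for the example.

```python
from pyspark.sql.protobuf.functions import from_protobuf

# df is assumed to hold serialized Protobuf messages in a binary column
# named "proto_bytes"; the descriptor path and message name are placeholders.
parsed = df.select(
    from_protobuf(
        "proto_bytes",
        "Person",                               # hypothetical message name
        descFilePath="/tmp/person.desc",        # hypothetical descriptor file
        options={
            "mode": "PERMISSIVE",               # null out corrupted records
            "recursive.fields.max.depth": "2",  # cap schema recursion depth
            "emit.default.values": "true",      # render type-specific zero values
        },
    ).alias("person")
)
parsed.show(truncate=False)
```

`to_protobuf` accepts an options map in the same way; the options listed in the table are read-scoped, so they take effect during deserialization with `from_protobuf`.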
diff --git a/docs/sql-distributed-sql-engine.md b/docs/sql-distributed-sql-engine.md index 734723f8c6235..ae8fd9c7211bd 100644 --- a/docs/sql-distributed-sql-engine.md +++ b/docs/sql-distributed-sql-engine.md @@ -83,7 +83,7 @@ Use the following setting to enable HTTP mode as system property or in `hive-sit To test, use beeline to connect to the JDBC/ODBC server in http mode with: - beeline> !connect jdbc:hive2://:/?hive.server2.transport.mode=http;hive.server2.thrift.http.path= + beeline> !connect jdbc:hive2://:/;transportMode=http;httpPath= If you closed a session and do CTAS, you must set `fs.%s.impl.disable.cache` to true in `hive-site.xml`. See more details in [[SPARK-21067]](https://issues.apache.org/jira/browse/SPARK-21067). @@ -94,4 +94,4 @@ To use the Spark SQL command line interface (CLI) from the shell: ./bin/spark-sql -For details, please refer to [Spark SQL CLI](sql-distributed-sql-engine-spark-sql-cli.html) \ No newline at end of file +For details, please refer to [Spark SQL CLI](sql-distributed-sql-engine-spark-sql-cli.html) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index ad678c44657ed..0ecd45c2d8c56 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -60,6 +60,7 @@ license: | - Since Spark 4.0, By default views tolerate column type changes in the query and compensate with casts. To restore the previous behavior, allowing up-casts only, set `spark.sql.legacy.viewSchemaCompensation` to `false`. - Since Spark 4.0, Views allow control over how they react to underlying query changes. By default views tolerate column type changes in the query and compensate with casts. To disable this feature set `spark.sql.legacy.viewSchemaBindingMode` to `false`. This also removes the clause from `DESCRIBE EXTENDED` and `SHOW CREATE TABLE`. - Since Spark 4.0, The Storage-Partitioned Join feature flag `spark.sql.sources.v2.bucketing.pushPartValues.enabled` is set to `true`. To restore the previous behavior, set `spark.sql.sources.v2.bucketing.pushPartValues.enabled` to `false`. +- Since Spark 4.0, the `sentences` function uses `Locale(language)` instead of `Locale.US` when `language` parameter is not `NULL` and `country` parameter is `NULL`. ## Upgrading from Spark SQL 3.5.1 to 3.5.2 diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index f5e1ddfd3c576..12dff1e325c49 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -141,13 +141,13 @@ In the table above, all the `CAST`s with new syntax are marked as red buildClassPath(String appClassPath) throws IOException { boolean prependClasses = !isEmpty(getenv("SPARK_PREPEND_CLASSES")); boolean isTesting = "1".equals(getenv("SPARK_TESTING")); + boolean isTestingSql = "1".equals(getenv("SPARK_SQL_TESTING")); + String jarsDir = findJarsDir(getSparkHome(), getScalaVersion(), !isTesting && !isTestingSql); if (prependClasses || isTesting) { String scala = getScalaVersion(); List projects = Arrays.asList( @@ -176,6 +178,9 @@ List buildClassPath(String appClassPath) throws IOException { "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of " + "assembly."); } + boolean shouldPrePendSparkHive = isJarAvailable(jarsDir, "spark-hive_"); + boolean shouldPrePendSparkHiveThriftServer = + shouldPrePendSparkHive && isJarAvailable(jarsDir, "spark-hive-thriftserver_"); for (String project : projects) { // Do not use locally compiled class files for Spark server because it should use shaded // dependencies. 
@@ -185,6 +190,24 @@ List buildClassPath(String appClassPath) throws IOException { if (isRemote && "1".equals(getenv("SPARK_SCALA_SHELL")) && project.equals("sql/core")) { continue; } + // SPARK-49534: The assumption here is that if `spark-hive_xxx.jar` is not in the + // classpath, then the `-Phive` profile was not used during package, and therefore + // the Hive-related jars should also not be in the classpath. To avoid failure in + // loading the SPI in `DataSourceRegister` under `sql/hive`, no longer prepend `sql/hive`. + if (!shouldPrePendSparkHive && project.equals("sql/hive")) { + continue; + } + // SPARK-49534: Meanwhile, due to the strong dependency of `sql/hive-thriftserver` + // on `sql/hive`, the prepend for `sql/hive-thriftserver` will also be excluded + // if `spark-hive_xxx.jar` is not in the classpath. On the other hand, if + // `spark-hive-thriftserver_xxx.jar` is not in the classpath, then the + // `-Phive-thriftserver` profile was not used during package, and therefore, + // jars such as hive-cli and hive-beeline should also not be included in the classpath. + // To avoid the inelegant startup failures of tools such as spark-sql, in this scenario, + // `sql/hive-thriftserver` will no longer be prepended to the classpath. + if (!shouldPrePendSparkHiveThriftServer && project.equals("sql/hive-thriftserver")) { + continue; + } addToClassPath(cp, String.format("%s/%s/target/scala-%s/classes", sparkHome, project, scala)); } @@ -205,8 +228,6 @@ List buildClassPath(String appClassPath) throws IOException { // Add Spark jars to the classpath. For the testing case, we rely on the test code to set and // propagate the test classpath appropriately. For normal invocation, look for the jars // directory under SPARK_HOME. - boolean isTestingSql = "1".equals(getenv("SPARK_SQL_TESTING")); - String jarsDir = findJarsDir(getSparkHome(), getScalaVersion(), !isTesting && !isTestingSql); if (jarsDir != null) { // Place slf4j-api-* jar first to be robust for (File f: new File(jarsDir).listFiles()) { @@ -214,7 +235,9 @@ List buildClassPath(String appClassPath) throws IOException { addToClassPath(cp, f.toString()); } } - if (isRemote && "1".equals(getenv("SPARK_SCALA_SHELL"))) { + // If we're in 'spark.local.connect', it should create a Spark Classic Spark Context + // that launches Spark Connect server. + if (isRemote && System.getenv("SPARK_LOCAL_CONNECT") == null) { for (File f: new File(jarsDir).listFiles()) { // Exclude Spark Classic SQL and Spark Connect server jars // if we're in Spark Connect Shell. Also exclude Spark SQL API and @@ -263,6 +286,24 @@ private void addToClassPath(Set cp, String entries) { } } + /** + * Checks if a JAR file with a specific prefix is available in the given directory. 
+ * + * @param jarsDir the directory to search for JAR files + * @param jarNamePrefix the prefix of the JAR file name to look for + * @return true if a JAR file with the specified prefix is found, false otherwise + */ + private boolean isJarAvailable(String jarsDir, String jarNamePrefix) { + if (jarsDir != null) { + for (File f : new File(jarsDir).listFiles()) { + if (f.getName().startsWith(jarNamePrefix)) { + return true; + } + } + } + return false; + } + String getScalaVersion() { String scala = getenv("SPARK_SCALA_VERSION"); if (scala != null) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index eebd04fe4c5b1..8d95bc06d7a7d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -82,10 +82,6 @@ public List buildCommand(Map env) javaOptsKeys.add("SPARK_BEELINE_OPTS"); yield "SPARK_BEELINE_MEMORY"; } - case "org.apache.spark.sql.application.ConnectRepl" -> { - isRemote = true; - yield "SPARK_DRIVER_MEMORY"; - } default -> "SPARK_DRIVER_MEMORY"; }; diff --git a/licenses-binary/LICENSE-xz.txt b/licenses-binary/LICENSE-xz.txt new file mode 100644 index 0000000000000..4322122aecf1a --- /dev/null +++ b/licenses-binary/LICENSE-xz.txt @@ -0,0 +1,11 @@ +Permission to use, copy, modify, and/or distribute this +software for any purpose with or without fee is hereby granted. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL +WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL +THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR +CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, +NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index b961f97cd877f..f97fefa245145 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -128,7 +128,7 @@ class CrossValidatorSuite exception = intercept[SparkIllegalArgumentException] { cv.fit(datasetWithFold) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map( "fieldName" -> "`fold1`", "fields" -> "`label`, `features`, `fold`") diff --git a/pom.xml b/pom.xml index 1cf74ed907ab1..b7c87beec0f92 100644 --- a/pom.xml +++ b/pom.xml @@ -124,7 +124,7 @@ 3.4.0 - 3.25.4 + 3.25.5 3.11.4 ${hadoop.version} 3.9.2 @@ -137,7 +137,7 @@ 3.8.0 10.16.1.1 - 1.14.1 + 1.14.2 2.0.2 shaded-protobuf 11.0.23 @@ -183,11 +183,11 @@ 2.17.2 2.17.2 2.3.1 - 1.1.10.6 + 1.1.10.7 3.0.3 1.17.1 1.27.1 - 2.16.1 + 2.17.0 2.6 @@ -195,11 +195,11 @@ 2.12.0 4.1.17 - 14.0.1 + 33.2.1-jre 2.11.0 3.1.9 3.0.12 - 2.12.7 + 2.13.0 3.5.2 3.0.0 2.2.11 @@ -212,10 +212,10 @@ 1.1.0 1.9.0 1.78 - 1.14.1 + 1.15.0 6.0.0 4.1.110.Final - 2.0.65.Final + 2.0.66.Final 75.1 5.11.0 1.11.0 @@ -227,6 +227,7 @@ --> 17.0.0 3.0.0-M2 + 0.12.6 org.fusesource.leveldbjni @@ -276,6 +277,7 @@ compile compile test + test false @@ -677,6 +679,23 @@ ivy ${ivy.version}
+ + io.jsonwebtoken + jjwt-api + ${jjwt.version} + + + io.jsonwebtoken + jjwt-impl + ${jjwt.version} + ${jjwt.deps.scope} + + + io.jsonwebtoken + jjwt-jackson + ${jjwt.version} + ${jjwt.deps.scope} + com.google.code.findbugs jsr305 @@ -2615,7 +2634,7 @@ io.airlift aircompressor - 0.27 + 2.0.2 org.apache.orc @@ -3401,6 +3420,7 @@ org.spark-project.spark:unused com.google.guava:guava + com.google.guava:failureaccess org.jpmml:* @@ -3809,6 +3829,9 @@ sparkr + + jjwt + aarch64 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 21638e4816309..ece4504395f12 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -125,8 +125,64 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation$"), + // SPARK-49414: Remove Logging from DataFrameReader. + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.DataFrameReader"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logName"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.log"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logInfo"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logDebug"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logTrace"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logWarning"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logError"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logInfo"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logDebug"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logTrace"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logWarning"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logError"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.isTraceEnabled"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeLogIfNecessary"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeLogIfNecessary"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeLogIfNecessary$default$2"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeForcefully"), + // SPARK-49425: Create a shared DataFrameWriter interface. - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameWriter") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameWriter"), + + // SPARK-49284: Shared Catalog interface. 
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.catalog.CatalogMetadata"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.catalog.Column"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.catalog.Database"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.catalog.Function"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.catalog.Table"), + + // SPARK-49426: Shared DataFrameWriterV2 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CreateTableWriter"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameWriterV2"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.WriteConfigMethods"), + + // SPARK-49424: Shared Encoders + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders$"), + + // SPARK-49413: Create a shared RuntimeConfig interface. + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RuntimeConfig"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RuntimeConfig$"), + + // SPARK-49287: Shared Streaming interfaces + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.SparkListenerEvent"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ForeachWriter"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.SourceProgress"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.SourceProgress$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StateOperatorProgress"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StateOperatorProgress$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$Event"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryIdleEvent"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgressEvent"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryStatus"), ) // Default exclude rules @@ -145,6 +201,8 @@ object MimaExcludes { ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.errors.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.classic.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.connect.*"), // DSv2 catalog and expression APIs are unstable yet. We should enable this back. 
ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.catalog.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.expressions.*"), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b9d1d62c5ca5a..2f390cb70baa8 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -89,7 +89,7 @@ object BuildCommons { // Google Protobuf version used for generating the protobuf. // SPARK-41247: needs to be consistent with `protobuf.version` in `pom.xml`. - val protoVersion = "3.25.4" + val protoVersion = "3.25.5" // GRPC version used for Spark Connect. val grpcVersion = "1.62.2" } @@ -420,11 +420,6 @@ object SparkBuild extends PomBuild { enable(DockerIntegrationTests.settings)(dockerIntegrationTests) - if (!profiles.contains("volcano")) { - enable(Volcano.settings)(kubernetes) - enable(Volcano.settings)(kubernetesIntegrationTests) - } - enable(KubernetesIntegrationTests.settings)(kubernetesIntegrationTests) enable(YARN.settings)(yarn) @@ -433,10 +428,6 @@ object SparkBuild extends PomBuild { enable(SparkR.settings)(core) } - if (!profiles.contains("opentelemetry")) { - enable(OpenTelemetry.settings)(core) - } - /** * Adds the ability to run the spark shell directly from SBT without building an assembly * jar. @@ -1060,7 +1051,7 @@ object KubernetesIntegrationTests { * Overrides to work around sbt's dependency resolution being different from Maven's. */ object DependencyOverrides { - lazy val guavaVersion = sys.props.get("guava.version").getOrElse("14.0.1") + lazy val guavaVersion = sys.props.get("guava.version").getOrElse("33.1.0-jre") lazy val settings = Seq( dependencyOverrides += "com.google.guava" % "guava" % guavaVersion, dependencyOverrides += "xerces" % "xercesImpl" % "2.12.2", @@ -1326,20 +1317,6 @@ object SparkR { ) } -object Volcano { - // Exclude all volcano file for Compile and Test - lazy val settings = Seq( - unmanagedSources / excludeFilter := HiddenFileFilter || "*Volcano*.scala" - ) -} - -object OpenTelemetry { - // Exclude all OpenTelemetry files for Compile and Test - lazy val settings = Seq( - unmanagedSources / excludeFilter := HiddenFileFilter || "OpenTelemetry*.scala" - ) -} - trait SharedUnidocSettings { import BuildCommons._ @@ -1375,6 +1352,7 @@ trait SharedUnidocSettings { .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/kvstore"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/connect/"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/classic/"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalog/v2/utils"))) diff --git a/python/docs/Makefile b/python/docs/Makefile index 5058c1206171b..428b0d24b568e 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -16,7 +16,7 @@ # Minimal makefile for Sphinx documentation # You can set these variables from the command line. 
-SPHINXOPTS ?= "-W" "-j" "auto" +SPHINXOPTS ?= "-W" "-j" "4" SPHINXBUILD ?= sphinx-build SOURCEDIR ?= source BUILDDIR ?= build diff --git a/python/docs/source/_static/spark-logo-dark.png b/python/docs/source/_static/spark-logo-dark.png deleted file mode 100644 index 7460faec37fc7..0000000000000 Binary files a/python/docs/source/_static/spark-logo-dark.png and /dev/null differ diff --git a/python/docs/source/_static/spark-logo-light.png b/python/docs/source/_static/spark-logo-light.png deleted file mode 100644 index 41938560822ca..0000000000000 Binary files a/python/docs/source/_static/spark-logo-light.png and /dev/null differ diff --git a/python/docs/source/conf.py b/python/docs/source/conf.py index 66b985092faf1..5640ba151176d 100644 --- a/python/docs/source/conf.py +++ b/python/docs/source/conf.py @@ -205,8 +205,8 @@ "navbar_end": ["version-switcher", "theme-switcher", "navbar-icon-links"], "footer_start": ["spark_footer", "sphinx-version"], "logo": { - "image_light": "_static/spark-logo-light.png", - "image_dark": "_static/spark-logo-dark.png", + "image_light": "https://spark.apache.org/images/spark-logo.png", + "image_dark": "https://spark.apache.org/images/spark-logo-rev.svg", }, "icon_links": [ { @@ -234,7 +234,7 @@ # The name of an image file (relative to this directory) to place at the top # of the sidebar. -html_logo = "../../../docs/img/spark-logo-reverse.png" +html_logo = "https://spark.apache.org/images/spark-logo-rev.svg" # The name of an image file (within the static path) to use as a favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst index 549656bea103e..88c0a8c26cc94 100644 --- a/python/docs/source/getting_started/install.rst +++ b/python/docs/source/getting_started/install.rst @@ -183,6 +183,7 @@ Package Supported version Note Additional libraries that enhance functionality but are not included in the installation packages: - **memory-profiler**: Used for PySpark UDF memory profiling, ``spark.profile.show(...)`` and ``spark.sql.pyspark.udf.profiler``. +- **plotly**: Used for PySpark plotting, ``DataFrame.plot``. Note that PySpark requires Java 17 or later with ``JAVA_HOME`` properly set and refer to |downloading|_. diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index dc4329c603241..4910a5b59273b 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -553,6 +553,7 @@ VARIANT Functions try_variant_get variant_get try_parse_json + to_variant_object XML Functions diff --git a/python/docs/source/user_guide/sql/python_data_source.rst b/python/docs/source/user_guide/sql/python_data_source.rst index 342b6f685d0b4..832987d19e5a4 100644 --- a/python/docs/source/user_guide/sql/python_data_source.rst +++ b/python/docs/source/user_guide/sql/python_data_source.rst @@ -452,3 +452,67 @@ We can also use the same data source in streaming reader and writer .. code-block:: python query = spark.readStream.format("fake").load().writeStream.format("fake").start("/output_path") + +Python Data Source Reader with direct Arrow Batch support for improved performance +---------------------------------------------------------------------------------- +The Python Datasource Reader supports direct yielding of Arrow Batches, which can significantly improve data processing performance. 
By using the efficient Arrow format, +this feature avoids the overhead of traditional row-by-row data processing, resulting in performance improvements of up to one order of magnitude, especially with large datasets. + +**Enabling Arrow Batch Support**: +To enable this feature, configure your custom DataSource to yield Arrow batches by returning `pyarrow.RecordBatch` objects within the `read` method of your `DataSourceReader` +(or `DataSourceStreamReader`) implementation. This method simplifies data handling and reduces the number of I/O operations, particularly beneficial for large-scale data processing tasks. + +**Arrow Batch Example**: +The following example demonstrates how to implement a basic Data Source using Arrow Batch support. + +.. code-block:: python + + from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition + from pyspark.sql import SparkSession + import pyarrow as pa + + # Define the ArrowBatchDataSource + class ArrowBatchDataSource(DataSource): + """ + A Data Source for testing Arrow Batch Serialization + """ + + @classmethod + def name(cls): + return "arrowbatch" + + def schema(self): + return "key int, value string" + + def reader(self, schema: str): + return ArrowBatchDataSourceReader(schema, self.options) + + # Define the ArrowBatchDataSourceReader + class ArrowBatchDataSourceReader(DataSourceReader): + def __init__(self, schema, options): + self.schema: str = schema + self.options = options + + def read(self, partition): + # Create Arrow Record Batch + keys = pa.array([1, 2, 3, 4, 5], type=pa.int32()) + values = pa.array(["one", "two", "three", "four", "five"], type=pa.string()) + schema = pa.schema([("key", pa.int32()), ("value", pa.string())]) + record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema) + yield record_batch + + def partitions(self): + # Define the number of partitions + num_part = 1 + return [InputPartition(i) for i in range(num_part)] + + # Initialize the Spark Session + spark = SparkSession.builder.appName("ArrowBatchExample").getOrCreate() + + # Register the ArrowBatchDataSource + spark.dataSource.register(ArrowBatchDataSource) + + # Load data using the custom data source + df = spark.read.format("arrowbatch").load() + + df.show() diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py index 79b74483f00dd..17cca326d0241 100755 --- a/python/packaging/classic/setup.py +++ b/python/packaging/classic/setup.py @@ -288,6 +288,7 @@ def run(self): "pyspark.sql.connect.streaming.worker", "pyspark.sql.functions", "pyspark.sql.pandas", + "pyspark.sql.plot", "pyspark.sql.protobuf", "pyspark.sql.streaming", "pyspark.sql.worker", diff --git a/python/packaging/connect/setup.py b/python/packaging/connect/setup.py index ab166c79747df..6ae16e9a9ad3a 100755 --- a/python/packaging/connect/setup.py +++ b/python/packaging/connect/setup.py @@ -77,6 +77,7 @@ "pyspark.sql.tests.connect.client", "pyspark.sql.tests.connect.shell", "pyspark.sql.tests.pandas", + "pyspark.sql.tests.plot", "pyspark.sql.tests.streaming", "pyspark.ml.tests.connect", "pyspark.pandas.tests", @@ -161,6 +162,7 @@ "pyspark.sql.connect.streaming.worker", "pyspark.sql.functions", "pyspark.sql.pandas", + "pyspark.sql.plot", "pyspark.sql.protobuf", "pyspark.sql.streaming", "pyspark.sql.worker", diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 4061d024a83cd..92aeb15e21d1b 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ 
-1088,6 +1088,11 @@ "Function `` should use only POSITIONAL or POSITIONAL OR KEYWORD arguments." ] }, + "UNSUPPORTED_PLOT_BACKEND": { + "message": [ + "`` is not supported, it should be one of the values from " + ] + }, "UNSUPPORTED_SIGNATURE": { "message": [ "Unsupported signature: ." diff --git a/python/pyspark/pandas/config.py b/python/pyspark/pandas/config.py index bfa88253dc6f4..6ed4adf21ff44 100644 --- a/python/pyspark/pandas/config.py +++ b/python/pyspark/pandas/config.py @@ -287,7 +287,8 @@ def validate(self, v: Any) -> None: doc=( "'plotting.sample_ratio' sets the proportion of data that will be plotted for sample-" "based plots such as `plot.line` and `plot.area`. " - "This option defaults to 'plotting.max_rows' option." + "If not set, it is derived from 'plotting.max_rows', by calculating the ratio of " + "'plotting.max_rows' to the total data size." ), default=None, types=(float, type(None)), diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py index 92d4a3357319f..4be345201ba65 100644 --- a/python/pyspark/pandas/internal.py +++ b/python/pyspark/pandas/internal.py @@ -43,6 +43,7 @@ ) from pyspark.sql.utils import is_timestamp_ntz_preferred, is_remote from pyspark import pandas as ps +from pyspark.pandas.spark import functions as SF from pyspark.pandas._typing import Label from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale from pyspark.pandas.data_type_ops.base import DataTypeOps @@ -938,19 +939,10 @@ def attach_distributed_sequence_column( +--------+---+ """ if len(sdf.columns) > 0: - if is_remote(): - from pyspark.sql.connect.column import Column as ConnectColumn - from pyspark.sql.connect.expressions import DistributedSequenceID - - return sdf.select( - ConnectColumn(DistributedSequenceID()).alias(column_name), - "*", - ) - else: - return PySparkDataFrame( - sdf._jdf.toDF().withSequenceColumn(column_name), - sdf.sparkSession, - ) + return sdf.select( + SF.distributed_sequence_id().alias(column_name), + "*", + ) else: cnt = sdf.count() if cnt > 0: diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index ea76dfa25bd99..6f036b7669246 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -15,7 +15,6 @@ # limitations under the License. 
# -import bisect import importlib import math @@ -25,7 +24,6 @@ from pandas.core.dtypes.inference import is_integer from pyspark.sql import functions as F, Column -from pyspark.sql.types import DoubleType from pyspark.pandas.spark import functions as SF from pyspark.pandas.missing import unsupported_function from pyspark.pandas.config import get_option @@ -70,19 +68,52 @@ class SampledPlotBase: def get_sampled(self, data): from pyspark.pandas import DataFrame, Series + if not isinstance(data, (DataFrame, Series)): + raise TypeError("Only DataFrame and Series are supported for plotting.") + if isinstance(data, Series): + data = data.to_frame() + fraction = get_option("plotting.sample_ratio") - if fraction is None: - fraction = 1 / (len(data) / get_option("plotting.max_rows")) - fraction = min(1.0, fraction) - self.fraction = fraction - - if isinstance(data, (DataFrame, Series)): - if isinstance(data, Series): - data = data.to_frame() + if fraction is not None: + self.fraction = fraction sampled = data._internal.resolved_copy.spark_frame.sample(fraction=self.fraction) return DataFrame(data._internal.with_new_sdf(sampled))._to_pandas() else: - raise TypeError("Only DataFrame and Series are supported for plotting.") + from pyspark.sql import Observation + + max_rows = get_option("plotting.max_rows") + observation = Observation("ps plotting") + sdf = data._internal.resolved_copy.spark_frame.observe( + observation, F.count(F.lit(1)).alias("count") + ) + + rand_col_name = "__ps_plotting_sampled_plot_base_rand__" + id_col_name = "__ps_plotting_sampled_plot_base_id__" + + sampled = ( + sdf.select( + "*", + F.rand().alias(rand_col_name), + F.monotonically_increasing_id().alias(id_col_name), + ) + .sort(rand_col_name) + .limit(max_rows + 1) + .coalesce(1) + .sortWithinPartitions(id_col_name) + .drop(rand_col_name, id_col_name) + ) + + pdf = DataFrame(data._internal.with_new_sdf(sampled))._to_pandas() + + if len(pdf) > max_rows: + try: + self.fraction = float(max_rows) / observation.get["count"] + except Exception: + pass + return pdf[:max_rows] + else: + self.fraction = 1.0 + return pdf def set_result_text(self, ax): assert hasattr(self, "fraction") @@ -182,22 +213,16 @@ def compute_hist(psdf, bins): colnames = sdf.columns bucket_names = ["__{}_bucket".format(colname) for colname in colnames] - # TODO(SPARK-49202): register this function in scala side - @F.udf(returnType=DoubleType()) - def binary_search_for_buckets(value): - # Given bins = [1.0, 2.0, 3.0, 4.0] - # the intervals are: - # [1.0, 2.0) -> 0.0 - # [2.0, 3.0) -> 1.0 - # [3.0, 4.0] -> 2.0 (the last bucket is a closed interval) - if value < bins[0] or value > bins[-1]: - raise ValueError(f"value {value} out of the bins bounds: [{bins[0]}, {bins[-1]}]") - - if value == bins[-1]: - idx = len(bins) - 2 - else: - idx = bisect.bisect(bins, value) - 1 - return float(idx) + # refers to org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets + def binary_search_for_buckets(value: Column): + index = SF.binary_search(F.lit(bins), value) + bucket = F.when(index >= 0, index).otherwise(-index - 2) + unboundErrMsg = F.lit(f"value %s out of the bins bounds: [{bins[0]}, {bins[-1]}]") + return ( + F.when(value == F.lit(bins[-1]), F.lit(len(bins) - 2)) + .when(value.between(F.lit(bins[0]), F.lit(bins[-1])), bucket) + .otherwise(F.raise_error(F.printf(unboundErrMsg, value))) + ) output_df = ( sdf.select( @@ -205,10 +230,10 @@ def binary_search_for_buckets(value): F.array([F.col(colname).cast("double") for colname in colnames]) ).alias("__group_id", 
"__value") ) - # to match handleInvalid="skip" in Bucketizer - .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN()).select( + .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN()) + .select( F.col("__group_id"), - binary_search_for_buckets(F.col("__value")).alias("__bucket"), + binary_search_for_buckets(F.col("__value")).cast("double").alias("__bucket"), ) ) @@ -454,7 +479,7 @@ class PandasOnSparkPlotAccessor(PandasObject): "pie": TopNPlotBase().get_top_n, "bar": TopNPlotBase().get_top_n, "barh": TopNPlotBase().get_top_n, - "scatter": TopNPlotBase().get_top_n, + "scatter": SampledPlotBase().get_sampled, "area": SampledPlotBase().get_sampled, "line": SampledPlotBase().get_sampled, } @@ -548,7 +573,7 @@ def line(self, x=None, y=None, **kwargs): """ Plot DataFrame/Series as lines. - This function is useful to plot lines using Series's values + This function is useful to plot lines using DataFrame’s values as coordinates. Parameters @@ -614,6 +639,12 @@ def bar(self, x=None, y=None, **kwds): """ Vertical bar plot. + A bar plot is a plot that presents categorical data with rectangular + bars with lengths proportional to the values that they represent. A + bar plot shows comparisons among discrete categories. One axis of the + plot shows the specific categories being compared, and the other axis + represents a measured value. + Parameters ---------- x : label or position, optional @@ -725,10 +756,10 @@ def barh(self, x=None, y=None, **kwargs): Parameters ---------- - x : label or position, default DataFrame.index - Column to be used for categories. - y : label or position, default All numeric columns in dataframe + x : label or position, default All numeric columns in dataframe Columns to be plotted from the DataFrame. + y : label or position, default DataFrame.index + Column to be used for categories. **kwds Keyword arguments to pass on to :meth:`pyspark.pandas.DataFrame.plot` or :meth:`pyspark.pandas.Series.plot`. @@ -739,6 +770,13 @@ def barh(self, x=None, y=None, **kwargs): Return an custom object when ``backend!=plotly``. Return an ndarray when ``subplots=True`` (matplotlib-only). + Notes + ----- + In Plotly and Matplotlib, the interpretation of `x` and `y` for `barh` plots differs. + In Plotly, `x` refers to the values and `y` refers to the categories. + In Matplotlib, `x` refers to the categories and `y` refers to the values. + Ensure correct axis labeling based on the backend used. + See Also -------- plotly.express.bar : Plot a vertical bar plot using plotly. @@ -805,7 +843,17 @@ def barh(self, x=None, y=None, **kwargs): def box(self, **kwds): """ - Make a box plot of the Series columns. + Make a box plot of the DataFrame columns. + + A box plot is a method for graphically depicting groups of numerical data through + their quartiles. The box extends from the Q1 to Q3 quartile values of the data, + with a line at the median (Q2). The whiskers extend from the edges of box to show + the range of the data. The position of the whiskers is set by default to + 1.5*IQR (IQR = Q3 - Q1) from the edges of the box. Outlier points are those past + the end of the whiskers. + + A consideration when using this chart is that the box and the whiskers can overlap, + which is very common when plotting small sets of data. Parameters ---------- @@ -859,9 +907,11 @@ def box(self, **kwds): def hist(self, bins=10, **kwds): """ Draw one histogram of the DataFrame’s columns. + A `histogram`_ is a representation of the distribution of data. 
This function calls :meth:`plotting.backend.plot`, on each series in the DataFrame, resulting in one histogram per column. + This is useful when the DataFrame’s Series are in a similar scale. .. _histogram: https://en.wikipedia.org/wiki/Histogram @@ -910,6 +960,10 @@ def kde(self, bw_method=None, ind=None, **kwargs): """ Generate Kernel Density Estimate plot using Gaussian kernels. + In statistics, kernel density estimation (KDE) is a non-parametric way to + estimate the probability density function (PDF) of a random variable. This + function uses Gaussian kernels and includes automatic bandwidth determination. + Parameters ---------- bw_method : scalar diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py index 6bef3d9b87c05..4bcf07f6f6503 100644 --- a/python/pyspark/pandas/spark/functions.py +++ b/python/pyspark/pandas/spark/functions.py @@ -174,6 +174,18 @@ def null_index(col: Column) -> Column: return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc)) +def distributed_sequence_id() -> Column: + if is_remote(): + from pyspark.sql.connect.functions.builtin import _invoke_function + + return _invoke_function("distributed_sequence_id") + else: + from pyspark import SparkContext + + sc = SparkContext._active_spark_context + return Column(sc._jvm.PythonSQLUtils.distributed_sequence_id()) + + def collect_top_k(col: Column, num: int, reverse: bool) -> Column: if is_remote(): from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns @@ -187,6 +199,19 @@ def collect_top_k(col: Column, num: int, reverse: bool) -> Column: return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num, reverse)) +def binary_search(col: Column, value: Column) -> Column: + if is_remote(): + from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns + + return _invoke_function_over_columns("array_binary_search", col, value) + + else: + from pyspark import SparkContext + + sc = SparkContext._active_spark_context + return Column(sc._jvm.PythonSQLUtils.binary_search(col._jc, value._jc)) + + def make_interval(unit: str, e: Union[Column, int, float]) -> Column: unit_mapping = { "YEAR": "years", diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py index 37469db2c8f51..8d197649aaebe 100644 --- a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py @@ -105,9 +105,10 @@ def check_barh_plot_with_x_y(pdf, psdf, x, y): self.assertEqual(pdf.plot.barh(x=x, y=y), psdf.plot.barh(x=x, y=y)) # this is testing plot with specified x and y - pdf1 = pd.DataFrame({"lab": ["A", "B", "C"], "val": [10, 30, 20]}) + pdf1 = pd.DataFrame({"lab": ["A", "B", "C"], "val": [10, 30, 20], "val2": [1.1, 2.2, 3.3]}) psdf1 = ps.from_pandas(pdf1) - check_barh_plot_with_x_y(pdf1, psdf1, x="lab", y="val") + check_barh_plot_with_x_y(pdf1, psdf1, x="val", y="lab") + check_barh_plot_with_x_y(pdf1, psdf1, x=["val", "val2"], y="lab") def test_barh_plot(self): def check_barh_plot(pdf, psdf): diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index 0e890e3343e66..a2778cbc32c4c 100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -73,6 +73,11 @@ from pyspark.sql.pandas.conversion import PandasConversionMixin from pyspark.sql.pandas.map_ops import PandasMapOpsMixin +try: + from pyspark.sql.plot import PySparkPlotAccessor +except ImportError: + 
PySparkPlotAccessor = None # type: ignore + if TYPE_CHECKING: from py4j.java_gateway import JavaObject import pyarrow as pa @@ -1849,6 +1854,12 @@ def toArrow(self) -> "pa.Table": def toPandas(self) -> "PandasDataFrameLike": return PandasConversionMixin.toPandas(self) + def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> ParentDataFrame: + if indexColumn is not None: + return DataFrame(self._jdf.transpose(_to_java_column(indexColumn)), self.sparkSession) + else: + return DataFrame(self._jdf.transpose(), self.sparkSession) + @property def executionInfo(self) -> Optional["ExecutionInfo"]: raise PySparkValueError( @@ -1856,6 +1867,10 @@ def executionInfo(self) -> Optional["ExecutionInfo"]: messageParameters={"member": "queryExecution"}, ) + @property + def plot(self) -> PySparkPlotAccessor: + return PySparkPlotAccessor(self) + class DataFrameNaFunctions(ParentDataFrameNaFunctions): def __init__(self, df: ParentDataFrame): diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 35dcf677fdb70..adba1b42a8bd6 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1786,7 +1786,7 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet req.user_context.user_id = self._user_id try: - return self._stub.FetchErrorDetails(req) + return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata()) except grpc.RpcError: return None diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 0309a25f956a1..ea6788e858317 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -284,8 +284,14 @@ def _call_iter(self, iter_fun: Callable) -> Any: raise e def _create_reattach_execute_request(self) -> pb2.ReattachExecuteRequest: + server_side_session_id = ( + None + if not self._initial_request.client_observed_server_side_session_id + else self._initial_request.client_observed_server_side_session_id + ) reattach = pb2.ReattachExecuteRequest( session_id=self._initial_request.session_id, + client_observed_server_side_session_id=server_side_session_id, user_context=self._initial_request.user_context, operation_id=self._initial_request.operation_id, ) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 442157eef0b75..59d79decf6690 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -86,6 +86,10 @@ from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined] +try: + from pyspark.sql.plot import PySparkPlotAccessor +except ImportError: + PySparkPlotAccessor = None # type: ignore if TYPE_CHECKING: from pyspark.sql.connect._typing import ( @@ -1783,7 +1787,7 @@ def __getitem__( ) ) else: - # TODO: revisit vanilla Spark's Dataset.col + # TODO: revisit classic Spark's Dataset.col # if (sparkSession.sessionState.conf.supportQuotedRegexColumnName) { # colRegex(colName) # } else { @@ -1858,6 +1862,12 @@ def toPandas(self) -> "PandasDataFrameLike": self._execution_info = ei return pdf + def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> ParentDataFrame: + return DataFrame( + plan.Transpose(self._plan, [F._to_col(indexColumn)] if indexColumn is not None else []), + self._session, + ) + @property def schema(self) -> StructType: # Schema caching is correct 
in most cases. Connect is lazy by nature. This means that @@ -2233,6 +2243,10 @@ def rdd(self) -> "RDD[Row]": def executionInfo(self) -> Optional["ExecutionInfo"]: return self._execution_info + @property + def plot(self) -> PySparkPlotAccessor: + return PySparkPlotAccessor(self) + class DataFrameNaFunctions(ParentDataFrameNaFunctions): def __init__(self, df: ParentDataFrame): diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index db1cd1c013be5..0b5512b61925c 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -477,8 +477,30 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": def __repr__(self) -> str: if self._value is None: return "NULL" - else: - return f"{self._value}" + elif isinstance(self._dataType, DateType): + dt = DateType().fromInternal(self._value) + if dt is not None and isinstance(dt, datetime.date): + return dt.strftime("%Y-%m-%d") + elif isinstance(self._dataType, TimestampType): + ts = TimestampType().fromInternal(self._value) + if ts is not None and isinstance(ts, datetime.datetime): + return ts.strftime("%Y-%m-%d %H:%M:%S.%f") + elif isinstance(self._dataType, TimestampNTZType): + ts = TimestampNTZType().fromInternal(self._value) + if ts is not None and isinstance(ts, datetime.datetime): + return ts.strftime("%Y-%m-%d %H:%M:%S.%f") + elif isinstance(self._dataType, DayTimeIntervalType): + delta = DayTimeIntervalType().fromInternal(self._value) + if delta is not None and isinstance(delta, datetime.timedelta): + import pandas as pd + + # Note: timedelta itself does not provide isoformat method. + # Both Pandas and java.time.Duration provide it, but the format + # is sightly different: + # java.time.Duration only applies HOURS, MINUTES, SECONDS units, + # while Pandas applies all supported units. + return pd.Timedelta(delta).isoformat() # type: ignore[attr-defined] + return f"{self._value}" class ColumnReference(Expression): diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index ad6dbbf58e48d..7fed175cbc8ea 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -71,6 +71,7 @@ StringType, ) from pyspark.sql.utils import enum_to_value as _enum_to_value +from pyspark.util import JVM_INT_MAX # The implementation of pandas_udf is embedded in pyspark.sql.function.pandas_udf # for code reuse. 
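As an aside to the `LiteralExpression.__repr__` change in python/pyspark/sql/connect/expressions.py above, a minimal sketch of the ISO-8601 rendering difference its comment describes (assuming only that pandas is installed; not part of the patch itself):

import datetime
import pandas as pd

delta = datetime.timedelta(days=1, hours=2, minutes=3, seconds=4)

# pandas spells out every unit from days down to seconds:
print(pd.Timedelta(delta).isoformat())  # e.g. 'P1DT2H3M4S'

# java.time.Duration.toString() would fold the day into hours for the same
# value and use only hours/minutes/seconds, e.g. 'PT26H3M4S'.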
@@ -1126,11 +1127,12 @@ def grouping_id(*cols: "ColumnOrName") -> Column: def count_min_sketch( col: "ColumnOrName", - eps: "ColumnOrName", - confidence: "ColumnOrName", - seed: "ColumnOrName", + eps: Union[Column, float], + confidence: Union[Column, float], + seed: Optional[Union[Column, int]] = None, ) -> Column: - return _invoke_function_over_columns("count_min_sketch", col, eps, confidence, seed) + _seed = lit(random.randint(0, JVM_INT_MAX)) if seed is None else lit(seed) + return _invoke_function_over_columns("count_min_sketch", col, lit(eps), lit(confidence), _seed) count_min_sketch.__doc__ = pysparkfuncs.count_min_sketch.__doc__ @@ -2071,6 +2073,13 @@ def try_parse_json(col: "ColumnOrName") -> Column: try_parse_json.__doc__ = pysparkfuncs.try_parse_json.__doc__ +def to_variant_object(col: "ColumnOrName") -> Column: + return _invoke_function("to_variant_object", _to_col(col)) + + +to_variant_object.__doc__ = pysparkfuncs.to_variant_object.__doc__ + + def parse_json(col: "ColumnOrName") -> Column: return _invoke_function("parse_json", _to_col(col)) @@ -2481,8 +2490,14 @@ def sentences( sentences.__doc__ = pysparkfuncs.sentences.__doc__ -def substring(str: "ColumnOrName", pos: int, len: int) -> Column: - return _invoke_function("substring", _to_col(str), lit(pos), lit(len)) +def substring( + str: "ColumnOrName", + pos: Union["ColumnOrName", int], + len: Union["ColumnOrName", int], +) -> Column: + _pos = lit(pos) if isinstance(pos, int) else _to_col(pos) + _len = lit(len) if isinstance(len, int) else _to_col(len) + return _invoke_function("substring", _to_col(str), _pos, _len) substring.__doc__ = pysparkfuncs.substring.__doc__ diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py index 69fbcda12ae72..46f13f893c7fa 100644 --- a/python/pyspark/sql/connect/group.py +++ b/python/pyspark/sql/connect/group.py @@ -101,7 +101,7 @@ def __init__( def __repr__(self) -> str: # the expressions are not resolved here, - # so the string representation can be different from vanilla PySpark. + # so the string representation can be different from classic PySpark. 
grouping_str = ", ".join(str(e._expr) for e in self._grouping_cols) grouping_str = f"grouping expressions: [{grouping_str}]" diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 958626280e41c..fbed0eabc684f 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -1329,6 +1329,27 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation: return plan +class Transpose(LogicalPlan): + """Logical plan object for a transpose operation.""" + + def __init__( + self, + child: Optional["LogicalPlan"], + index_columns: Sequence[Column], + ) -> None: + super().__init__(child) + self.index_columns = index_columns + + def plan(self, session: "SparkConnectClient") -> proto.Relation: + assert self._child is not None + plan = self._create_proto_relation() + plan.transpose.input.CopyFrom(self._child.plan(session)) + if self.index_columns is not None and len(self.index_columns) > 0: + for index_column in self.index_columns: + plan.transpose.index_columns.append(index_column.to_plan(session)) + return plan + + class CollectMetrics(LogicalPlan): """Logical plan object for a CollectMetrics operation.""" diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 9f4d1e717a28d..ee625241600ff 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -36,7 +36,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto"\xe9\x1a\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12*\n\x05to_df\x18\x12 \x01(\x0b\x32\x13.spark.connect.ToDFH\x00R\x04toDf\x12U\n\x14with_columns_renamed\x18\x13 \x01(\x0b\x32!.spark.connect.WithColumnsRenamedH\x00R\x12withColumnsRenamed\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 
\x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x32\n\x07unpivot\x18\x19 \x01(\x0b\x32\x16.spark.connect.UnpivotH\x00R\x07unpivot\x12\x36\n\tto_schema\x18\x1a \x01(\x0b\x32\x17.spark.connect.ToSchemaH\x00R\x08toSchema\x12\x64\n\x19repartition_by_expression\x18\x1b \x01(\x0b\x32&.spark.connect.RepartitionByExpressionH\x00R\x17repartitionByExpression\x12\x45\n\x0emap_partitions\x18\x1c \x01(\x0b\x32\x1c.spark.connect.MapPartitionsH\x00R\rmapPartitions\x12H\n\x0f\x63ollect_metrics\x18\x1d \x01(\x0b\x32\x1d.spark.connect.CollectMetricsH\x00R\x0e\x63ollectMetrics\x12,\n\x05parse\x18\x1e \x01(\x0b\x32\x14.spark.connect.ParseH\x00R\x05parse\x12\x36\n\tgroup_map\x18\x1f \x01(\x0b\x32\x17.spark.connect.GroupMapH\x00R\x08groupMap\x12=\n\x0c\x63o_group_map\x18 \x01(\x0b\x32\x19.spark.connect.CoGroupMapH\x00R\ncoGroupMap\x12\x45\n\x0ewith_watermark\x18! \x01(\x0b\x32\x1c.spark.connect.WithWatermarkH\x00R\rwithWatermark\x12\x63\n\x1a\x61pply_in_pandas_with_state\x18" \x01(\x0b\x32%.spark.connect.ApplyInPandasWithStateH\x00R\x16\x61pplyInPandasWithState\x12<\n\x0bhtml_string\x18# \x01(\x0b\x32\x19.spark.connect.HtmlStringH\x00R\nhtmlString\x12X\n\x15\x63\x61\x63hed_local_relation\x18$ \x01(\x0b\x32".spark.connect.CachedLocalRelationH\x00R\x13\x63\x61\x63hedLocalRelation\x12[\n\x16\x63\x61\x63hed_remote_relation\x18% \x01(\x0b\x32#.spark.connect.CachedRemoteRelationH\x00R\x14\x63\x61\x63hedRemoteRelation\x12\x8e\x01\n)common_inline_user_defined_table_function\x18& \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R$commonInlineUserDefinedTableFunction\x12\x37\n\nas_of_join\x18\' \x01(\x0b\x32\x17.spark.connect.AsOfJoinH\x00R\x08\x61sOfJoin\x12\x85\x01\n&common_inline_user_defined_data_source\x18( \x01(\x0b\x32\x30.spark.connect.CommonInlineUserDefinedDataSourceH\x00R!commonInlineUserDefinedDataSource\x12\x45\n\x0ewith_relations\x18) \x01(\x0b\x32\x1c.spark.connect.WithRelationsH\x00R\rwithRelations\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12*\n\x03\x63ov\x18g \x01(\x0b\x32\x16.spark.connect.StatCovH\x00R\x03\x63ov\x12-\n\x04\x63orr\x18h \x01(\x0b\x32\x17.spark.connect.StatCorrH\x00R\x04\x63orr\x12L\n\x0f\x61pprox_quantile\x18i \x01(\x0b\x32!.spark.connect.StatApproxQuantileH\x00R\x0e\x61pproxQuantile\x12=\n\nfreq_items\x18j \x01(\x0b\x32\x1c.spark.connect.StatFreqItemsH\x00R\tfreqItems\x12:\n\tsample_by\x18k \x01(\x0b\x32\x1b.spark.connect.StatSampleByH\x00R\x08sampleBy\x12\x33\n\x07\x63\x61talog\x18\xc8\x01 \x01(\x0b\x32\x16.spark.connect.CatalogH\x00R\x07\x63\x61talog\x12\x35\n\textension\x18\xe6\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x33\n\x07unknown\x18\xe7\x07 
\x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"\x8e\x01\n\x0eRelationCommon\x12#\n\x0bsource_info\x18\x01 \x01(\tB\x02\x18\x01R\nsourceInfo\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x12-\n\x06origin\x18\x03 \x01(\x0b\x32\x15.spark.connect.OriginR\x06originB\n\n\x08_plan_id"\xde\x03\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query\x12\x34\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x1c.spark.connect.SQL.ArgsEntryB\x02\x18\x01R\x04\x61rgs\x12@\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralB\x02\x18\x01R\x07posArgs\x12O\n\x0fnamed_arguments\x18\x04 \x03(\x0b\x32&.spark.connect.SQL.NamedArgumentsEntryR\x0enamedArguments\x12>\n\rpos_arguments\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cposArguments\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01\x1a\\\n\x13NamedArgumentsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value:\x02\x38\x01"u\n\rWithRelations\x12+\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04root\x12\x37\n\nreferences\x18\x02 \x03(\x0b\x32\x17.spark.connect.RelationR\nreferences"\x97\x05\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x12!\n\x0cis_streaming\x18\x03 \x01(\x08R\x0bisStreaming\x1a\xc0\x01\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x45\n\x07options\x18\x02 \x03(\x0b\x32+.spark.connect.Read.NamedTable.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x95\x02\n\nDataSource\x12\x1b\n\x06\x66ormat\x18\x01 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x12\x14\n\x05paths\x18\x04 \x03(\tR\x05paths\x12\x1e\n\npredicates\x18\x05 \x03(\tR\npredicates\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x95\x05\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns\x12K\n\x0ejoin_data_type\x18\x06 \x01(\x0b\x32 .spark.connect.Join.JoinDataTypeH\x00R\x0cjoinDataType\x88\x01\x01\x1a\\\n\x0cJoinDataType\x12$\n\x0eis_left_struct\x18\x01 \x01(\x08R\x0cisLeftStruct\x12&\n\x0fis_right_struct\x18\x02 
\x01(\x08R\risRightStruct"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07\x42\x11\n\x0f_join_data_type"\xdf\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01\x12\x37\n\x15\x61llow_missing_columns\x18\x06 \x01(\x08H\x02R\x13\x61llowMissingColumns\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_nameB\x18\n\x16_allow_missing_columns"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xfe\x05\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x41\n\ngroup_type\x18\x02 \x01(\x0e\x32".spark.connect.Aggregate.GroupTypeR\tgroupType\x12L\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12N\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x14\x61ggregateExpressions\x12\x34\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.spark.connect.Aggregate.PivotR\x05pivot\x12J\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.spark.connect.Aggregate.GroupingSetsR\x0cgroupingSets\x1ao\n\x05Pivot\x12+\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1aL\n\x0cGroupingSets\x12<\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0bgroupingSet"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05"\xa0\x01\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x05order\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\x05order\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x42\x0c\n\n_is_global"\x8d\x01\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x33\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07\x63olumns\x12!\n\x0c\x63olumn_names\x18\x03 \x03(\tR\x0b\x63olumnNames"\xf0\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 
\x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x12.\n\x10within_watermark\x18\x04 \x01(\x08H\x01R\x0fwithinWatermark\x88\x01\x01\x42\x16\n\x14_all_columns_as_keysB\x13\n\x11_within_watermark"Y\n\rLocalRelation\x12\x17\n\x04\x64\x61ta\x18\x01 \x01(\x0cH\x00R\x04\x64\x61ta\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x42\x07\n\x05_dataB\t\n\x07_schema"H\n\x13\x43\x61\x63hedLocalRelation\x12\x12\n\x04hash\x18\x03 \x01(\tR\x04hashJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03R\x06userIdR\tsessionId"7\n\x14\x43\x61\x63hedRemoteRelation\x12\x1f\n\x0brelation_id\x18\x01 \x01(\tR\nrelationId"\x91\x02\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x12/\n\x13\x64\x65terministic_order\x18\x06 \x01(\x08R\x12\x64\x65terministicOrderB\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8e\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"r\n\nHtmlString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"`\n\x07StatCov\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x89\x01\n\x08StatCorr\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2\x12\x1b\n\x06method\x18\x04 \x01(\tH\x00R\x06method\x88\x01\x01\x42\t\n\x07_method"\xa4\x01\n\x12StatApproxQuantile\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12$\n\rprobabilities\x18\x03 \x03(\x01R\rprobabilities\x12%\n\x0erelative_error\x18\x04 
\x01(\x01R\rrelativeError"}\n\rStatFreqItems\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x1d\n\x07support\x18\x03 \x01(\x01H\x00R\x07support\x88\x01\x01\x42\n\n\x08_support"\xb5\x02\n\x0cStatSampleBy\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03\x63ol\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x42\n\tfractions\x18\x03 \x03(\x0b\x32$.spark.connect.StatSampleBy.FractionR\tfractions\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x00R\x04seed\x88\x01\x01\x1a\x63\n\x08\x46raction\x12;\n\x07stratum\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x07stratum\x12\x1a\n\x08\x66raction\x18\x02 \x01(\x01R\x08\x66ractionB\x07\n\x05_seed"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"X\n\x04ToDF\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\xfe\x02\n\x12WithColumnsRenamed\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12i\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x37.spark.connect.WithColumnsRenamed.RenameColumnsMapEntryB\x02\x18\x01R\x10renameColumnsMap\x12\x42\n\x07renames\x18\x03 \x03(\x0b\x32(.spark.connect.WithColumnsRenamed.RenameR\x07renames\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x45\n\x06Rename\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12 \n\x0cnew_col_name\x18\x02 \x01(\tR\nnewColName"w\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x07\x61liases\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x07\x61liases"\x86\x01\n\rWithWatermark\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\nevent_time\x18\x02 \x01(\tR\teventTime\x12\'\n\x0f\x64\x65lay_threshold\x18\x03 \x01(\tR\x0e\x64\x65layThreshold"\x84\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x39\n\nparameters\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\nparameters"\xc7\x02\n\x07Unpivot\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03ids\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x03ids\x12:\n\x06values\x18\x03 \x01(\x0b\x32\x1d.spark.connect.Unpivot.ValuesH\x00R\x06values\x88\x01\x01\x12\x30\n\x14variable_column_name\x18\x04 \x01(\tR\x12variableColumnName\x12*\n\x11value_column_name\x18\x05 \x01(\tR\x0fvalueColumnName\x1a;\n\x06Values\x12\x31\n\x06values\x18\x01 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\x06valuesB\t\n\x07_values"j\n\x08ToSchema\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema"\xcb\x01\n\x17RepartitionByExpression\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x0fpartition_exprs\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0epartitionExprs\x12*\n\x0enum_partitions\x18\x03 \x01(\x05H\x00R\rnumPartitions\x88\x01\x01\x42\x11\n\x0f_num_partitions"\xe8\x01\n\rMapPartitions\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x04\x66unc\x18\x02 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12"\n\nis_barrier\x18\x03 \x01(\x08H\x00R\tisBarrier\x88\x01\x01\x12"\n\nprofile_id\x18\x04 \x01(\x05H\x01R\tprofileId\x88\x01\x01\x42\r\n\x0b_is_barrierB\r\n\x0b_profile_id"\xfb\x04\n\x08GroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12J\n\x13sorting_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x12sortingExpressions\x12<\n\rinitial_input\x18\x05 \x01(\x0b\x32\x17.spark.connect.RelationR\x0cinitialInput\x12[\n\x1cinitial_grouping_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x1ainitialGroupingExpressions\x12;\n\x18is_map_groups_with_state\x18\x07 \x01(\x08H\x00R\x14isMapGroupsWithState\x88\x01\x01\x12$\n\x0boutput_mode\x18\x08 \x01(\tH\x01R\noutputMode\x88\x01\x01\x12&\n\x0ctimeout_conf\x18\t \x01(\tH\x02R\x0btimeoutConf\x88\x01\x01\x42\x1b\n\x19_is_map_groups_with_stateB\x0e\n\x0c_output_modeB\x0f\n\r_timeout_conf"\x8e\x04\n\nCoGroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12W\n\x1ainput_grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18inputGroupingExpressions\x12-\n\x05other\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x05other\x12W\n\x1aother_grouping_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18otherGroupingExpressions\x12\x42\n\x04\x66unc\x18\x05 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12U\n\x19input_sorting_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17inputSortingExpressions\x12U\n\x19other_sorting_expressions\x18\x07 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17otherSortingExpressions"\xe5\x02\n\x16\x41pplyInPandasWithState\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12#\n\routput_schema\x18\x04 \x01(\tR\x0coutputSchema\x12!\n\x0cstate_schema\x18\x05 \x01(\tR\x0bstateSchema\x12\x1f\n\x0boutput_mode\x18\x06 \x01(\tR\noutputMode\x12!\n\x0ctimeout_conf\x18\x07 \x01(\tR\x0btimeoutConf"\xf4\x01\n$CommonInlineUserDefinedTableFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12<\n\x0bpython_udtf\x18\x04 
\x01(\x0b\x32\x19.spark.connect.PythonUDTFH\x00R\npythonUdtfB\n\n\x08\x66unction"\xb1\x01\n\nPythonUDTF\x12=\n\x0breturn_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\nreturnType\x88\x01\x01\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVerB\x0e\n\x0c_return_type"\x97\x01\n!CommonInlineUserDefinedDataSource\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12O\n\x12python_data_source\x18\x02 \x01(\x0b\x32\x1f.spark.connect.PythonDataSourceH\x00R\x10pythonDataSourceB\r\n\x0b\x64\x61ta_source"K\n\x10PythonDataSource\x12\x18\n\x07\x63ommand\x18\x01 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x02 \x01(\tR\tpythonVer"\x88\x01\n\x0e\x43ollectMetrics\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x33\n\x07metrics\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07metrics"\x84\x03\n\x05Parse\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x38\n\x06\x66ormat\x18\x02 \x01(\x0e\x32 .spark.connect.Parse.ParseFormatR\x06\x66ormat\x12\x34\n\x06schema\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x06schema\x88\x01\x01\x12;\n\x07options\x18\x04 \x03(\x0b\x32!.spark.connect.Parse.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"X\n\x0bParseFormat\x12\x1c\n\x18PARSE_FORMAT_UNSPECIFIED\x10\x00\x12\x14\n\x10PARSE_FORMAT_CSV\x10\x01\x12\x15\n\x11PARSE_FORMAT_JSON\x10\x02\x42\t\n\x07_schema"\xdb\x03\n\x08\x41sOfJoin\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12\x37\n\nleft_as_of\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08leftAsOf\x12\x39\n\x0bright_as_of\x18\x04 \x01(\x0b\x32\x19.spark.connect.ExpressionR\trightAsOf\x12\x36\n\tjoin_expr\x18\x05 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08joinExpr\x12#\n\rusing_columns\x18\x06 \x03(\tR\x0cusingColumns\x12\x1b\n\tjoin_type\x18\x07 \x01(\tR\x08joinType\x12\x37\n\ttolerance\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\ttolerance\x12.\n\x13\x61llow_exact_matches\x18\t \x01(\x08R\x11\x61llowExactMatches\x12\x1c\n\tdirection\x18\n \x01(\tR\tdirectionB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto"\xa3\x1b\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b 
\x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12*\n\x05to_df\x18\x12 \x01(\x0b\x32\x13.spark.connect.ToDFH\x00R\x04toDf\x12U\n\x14with_columns_renamed\x18\x13 \x01(\x0b\x32!.spark.connect.WithColumnsRenamedH\x00R\x12withColumnsRenamed\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x32\n\x07unpivot\x18\x19 \x01(\x0b\x32\x16.spark.connect.UnpivotH\x00R\x07unpivot\x12\x36\n\tto_schema\x18\x1a \x01(\x0b\x32\x17.spark.connect.ToSchemaH\x00R\x08toSchema\x12\x64\n\x19repartition_by_expression\x18\x1b \x01(\x0b\x32&.spark.connect.RepartitionByExpressionH\x00R\x17repartitionByExpression\x12\x45\n\x0emap_partitions\x18\x1c \x01(\x0b\x32\x1c.spark.connect.MapPartitionsH\x00R\rmapPartitions\x12H\n\x0f\x63ollect_metrics\x18\x1d \x01(\x0b\x32\x1d.spark.connect.CollectMetricsH\x00R\x0e\x63ollectMetrics\x12,\n\x05parse\x18\x1e \x01(\x0b\x32\x14.spark.connect.ParseH\x00R\x05parse\x12\x36\n\tgroup_map\x18\x1f \x01(\x0b\x32\x17.spark.connect.GroupMapH\x00R\x08groupMap\x12=\n\x0c\x63o_group_map\x18 \x01(\x0b\x32\x19.spark.connect.CoGroupMapH\x00R\ncoGroupMap\x12\x45\n\x0ewith_watermark\x18! 
\x01(\x0b\x32\x1c.spark.connect.WithWatermarkH\x00R\rwithWatermark\x12\x63\n\x1a\x61pply_in_pandas_with_state\x18" \x01(\x0b\x32%.spark.connect.ApplyInPandasWithStateH\x00R\x16\x61pplyInPandasWithState\x12<\n\x0bhtml_string\x18# \x01(\x0b\x32\x19.spark.connect.HtmlStringH\x00R\nhtmlString\x12X\n\x15\x63\x61\x63hed_local_relation\x18$ \x01(\x0b\x32".spark.connect.CachedLocalRelationH\x00R\x13\x63\x61\x63hedLocalRelation\x12[\n\x16\x63\x61\x63hed_remote_relation\x18% \x01(\x0b\x32#.spark.connect.CachedRemoteRelationH\x00R\x14\x63\x61\x63hedRemoteRelation\x12\x8e\x01\n)common_inline_user_defined_table_function\x18& \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R$commonInlineUserDefinedTableFunction\x12\x37\n\nas_of_join\x18\' \x01(\x0b\x32\x17.spark.connect.AsOfJoinH\x00R\x08\x61sOfJoin\x12\x85\x01\n&common_inline_user_defined_data_source\x18( \x01(\x0b\x32\x30.spark.connect.CommonInlineUserDefinedDataSourceH\x00R!commonInlineUserDefinedDataSource\x12\x45\n\x0ewith_relations\x18) \x01(\x0b\x32\x1c.spark.connect.WithRelationsH\x00R\rwithRelations\x12\x38\n\ttranspose\x18* \x01(\x0b\x32\x18.spark.connect.TransposeH\x00R\ttranspose\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12*\n\x03\x63ov\x18g \x01(\x0b\x32\x16.spark.connect.StatCovH\x00R\x03\x63ov\x12-\n\x04\x63orr\x18h \x01(\x0b\x32\x17.spark.connect.StatCorrH\x00R\x04\x63orr\x12L\n\x0f\x61pprox_quantile\x18i \x01(\x0b\x32!.spark.connect.StatApproxQuantileH\x00R\x0e\x61pproxQuantile\x12=\n\nfreq_items\x18j \x01(\x0b\x32\x1c.spark.connect.StatFreqItemsH\x00R\tfreqItems\x12:\n\tsample_by\x18k \x01(\x0b\x32\x1b.spark.connect.StatSampleByH\x00R\x08sampleBy\x12\x33\n\x07\x63\x61talog\x18\xc8\x01 \x01(\x0b\x32\x16.spark.connect.CatalogH\x00R\x07\x63\x61talog\x12\x35\n\textension\x18\xe6\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"\x8e\x01\n\x0eRelationCommon\x12#\n\x0bsource_info\x18\x01 \x01(\tB\x02\x18\x01R\nsourceInfo\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x12-\n\x06origin\x18\x03 \x01(\x0b\x32\x15.spark.connect.OriginR\x06originB\n\n\x08_plan_id"\xde\x03\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query\x12\x34\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x1c.spark.connect.SQL.ArgsEntryB\x02\x18\x01R\x04\x61rgs\x12@\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralB\x02\x18\x01R\x07posArgs\x12O\n\x0fnamed_arguments\x18\x04 \x03(\x0b\x32&.spark.connect.SQL.NamedArgumentsEntryR\x0enamedArguments\x12>\n\rpos_arguments\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cposArguments\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01\x1a\\\n\x13NamedArgumentsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value:\x02\x38\x01"u\n\rWithRelations\x12+\n\x04root\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x04root\x12\x37\n\nreferences\x18\x02 \x03(\x0b\x32\x17.spark.connect.RelationR\nreferences"\x97\x05\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x12!\n\x0cis_streaming\x18\x03 \x01(\x08R\x0bisStreaming\x1a\xc0\x01\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x45\n\x07options\x18\x02 \x03(\x0b\x32+.spark.connect.Read.NamedTable.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x95\x02\n\nDataSource\x12\x1b\n\x06\x66ormat\x18\x01 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x12\x14\n\x05paths\x18\x04 \x03(\tR\x05paths\x12\x1e\n\npredicates\x18\x05 \x03(\tR\npredicates\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x95\x05\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns\x12K\n\x0ejoin_data_type\x18\x06 \x01(\x0b\x32 .spark.connect.Join.JoinDataTypeH\x00R\x0cjoinDataType\x88\x01\x01\x1a\\\n\x0cJoinDataType\x12$\n\x0eis_left_struct\x18\x01 \x01(\x08R\x0cisLeftStruct\x12&\n\x0fis_right_struct\x18\x02 \x01(\x08R\risRightStruct"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07\x42\x11\n\x0f_join_data_type"\xdf\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01\x12\x37\n\x15\x61llow_missing_columns\x18\x06 \x01(\x08H\x02R\x13\x61llowMissingColumns\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_nameB\x18\n\x16_allow_missing_columns"L\n\x05Limit\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xfe\x05\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x41\n\ngroup_type\x18\x02 \x01(\x0e\x32".spark.connect.Aggregate.GroupTypeR\tgroupType\x12L\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12N\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x14\x61ggregateExpressions\x12\x34\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.spark.connect.Aggregate.PivotR\x05pivot\x12J\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.spark.connect.Aggregate.GroupingSetsR\x0cgroupingSets\x1ao\n\x05Pivot\x12+\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1aL\n\x0cGroupingSets\x12<\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0bgroupingSet"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05"\xa0\x01\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x05order\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\x05order\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x42\x0c\n\n_is_global"\x8d\x01\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x33\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07\x63olumns\x12!\n\x0c\x63olumn_names\x18\x03 \x03(\tR\x0b\x63olumnNames"\xf0\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x12.\n\x10within_watermark\x18\x04 \x01(\x08H\x01R\x0fwithinWatermark\x88\x01\x01\x42\x16\n\x14_all_columns_as_keysB\x13\n\x11_within_watermark"Y\n\rLocalRelation\x12\x17\n\x04\x64\x61ta\x18\x01 \x01(\x0cH\x00R\x04\x64\x61ta\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x42\x07\n\x05_dataB\t\n\x07_schema"H\n\x13\x43\x61\x63hedLocalRelation\x12\x12\n\x04hash\x18\x03 \x01(\tR\x04hashJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03R\x06userIdR\tsessionId"7\n\x14\x43\x61\x63hedRemoteRelation\x12\x1f\n\x0brelation_id\x18\x01 \x01(\tR\nrelationId"\x91\x02\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x12/\n\x13\x64\x65terministic_order\x18\x06 \x01(\x08R\x12\x64\x65terministicOrderB\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 
\x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8e\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"r\n\nHtmlString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"`\n\x07StatCov\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x89\x01\n\x08StatCorr\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2\x12\x1b\n\x06method\x18\x04 \x01(\tH\x00R\x06method\x88\x01\x01\x42\t\n\x07_method"\xa4\x01\n\x12StatApproxQuantile\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12$\n\rprobabilities\x18\x03 \x03(\x01R\rprobabilities\x12%\n\x0erelative_error\x18\x04 \x01(\x01R\rrelativeError"}\n\rStatFreqItems\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x1d\n\x07support\x18\x03 \x01(\x01H\x00R\x07support\x88\x01\x01\x42\n\n\x08_support"\xb5\x02\n\x0cStatSampleBy\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03\x63ol\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x42\n\tfractions\x18\x03 \x03(\x0b\x32$.spark.connect.StatSampleBy.FractionR\tfractions\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x00R\x04seed\x88\x01\x01\x1a\x63\n\x08\x46raction\x12;\n\x07stratum\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x07stratum\x12\x1a\n\x08\x66raction\x18\x02 \x01(\x01R\x08\x66ractionB\x07\n\x05_seed"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"X\n\x04ToDF\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\xfe\x02\n\x12WithColumnsRenamed\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12i\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x37.spark.connect.WithColumnsRenamed.RenameColumnsMapEntryB\x02\x18\x01R\x10renameColumnsMap\x12\x42\n\x07renames\x18\x03 \x03(\x0b\x32(.spark.connect.WithColumnsRenamed.RenameR\x07renames\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x45\n\x06Rename\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12 \n\x0cnew_col_name\x18\x02 \x01(\tR\nnewColName"w\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x07\x61liases\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x07\x61liases"\x86\x01\n\rWithWatermark\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\nevent_time\x18\x02 \x01(\tR\teventTime\x12\'\n\x0f\x64\x65lay_threshold\x18\x03 \x01(\tR\x0e\x64\x65layThreshold"\x84\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x39\n\nparameters\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\nparameters"\xc7\x02\n\x07Unpivot\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03ids\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x03ids\x12:\n\x06values\x18\x03 \x01(\x0b\x32\x1d.spark.connect.Unpivot.ValuesH\x00R\x06values\x88\x01\x01\x12\x30\n\x14variable_column_name\x18\x04 \x01(\tR\x12variableColumnName\x12*\n\x11value_column_name\x18\x05 \x01(\tR\x0fvalueColumnName\x1a;\n\x06Values\x12\x31\n\x06values\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x06valuesB\t\n\x07_values"z\n\tTranspose\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\rindex_columns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cindexColumns"j\n\x08ToSchema\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema"\xcb\x01\n\x17RepartitionByExpression\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x0fpartition_exprs\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0epartitionExprs\x12*\n\x0enum_partitions\x18\x03 \x01(\x05H\x00R\rnumPartitions\x88\x01\x01\x42\x11\n\x0f_num_partitions"\xe8\x01\n\rMapPartitions\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x04\x66unc\x18\x02 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12"\n\nis_barrier\x18\x03 \x01(\x08H\x00R\tisBarrier\x88\x01\x01\x12"\n\nprofile_id\x18\x04 \x01(\x05H\x01R\tprofileId\x88\x01\x01\x42\r\n\x0b_is_barrierB\r\n\x0b_profile_id"\xfb\x04\n\x08GroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12J\n\x13sorting_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x12sortingExpressions\x12<\n\rinitial_input\x18\x05 \x01(\x0b\x32\x17.spark.connect.RelationR\x0cinitialInput\x12[\n\x1cinitial_grouping_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x1ainitialGroupingExpressions\x12;\n\x18is_map_groups_with_state\x18\x07 \x01(\x08H\x00R\x14isMapGroupsWithState\x88\x01\x01\x12$\n\x0boutput_mode\x18\x08 \x01(\tH\x01R\noutputMode\x88\x01\x01\x12&\n\x0ctimeout_conf\x18\t \x01(\tH\x02R\x0btimeoutConf\x88\x01\x01\x42\x1b\n\x19_is_map_groups_with_stateB\x0e\n\x0c_output_modeB\x0f\n\r_timeout_conf"\x8e\x04\n\nCoGroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12W\n\x1ainput_grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18inputGroupingExpressions\x12-\n\x05other\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x05other\x12W\n\x1aother_grouping_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18otherGroupingExpressions\x12\x42\n\x04\x66unc\x18\x05 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12U\n\x19input_sorting_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17inputSortingExpressions\x12U\n\x19other_sorting_expressions\x18\x07 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17otherSortingExpressions"\xe5\x02\n\x16\x41pplyInPandasWithState\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12#\n\routput_schema\x18\x04 \x01(\tR\x0coutputSchema\x12!\n\x0cstate_schema\x18\x05 \x01(\tR\x0bstateSchema\x12\x1f\n\x0boutput_mode\x18\x06 \x01(\tR\noutputMode\x12!\n\x0ctimeout_conf\x18\x07 \x01(\tR\x0btimeoutConf"\xf4\x01\n$CommonInlineUserDefinedTableFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12<\n\x0bpython_udtf\x18\x04 \x01(\x0b\x32\x19.spark.connect.PythonUDTFH\x00R\npythonUdtfB\n\n\x08\x66unction"\xb1\x01\n\nPythonUDTF\x12=\n\x0breturn_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\nreturnType\x88\x01\x01\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVerB\x0e\n\x0c_return_type"\x97\x01\n!CommonInlineUserDefinedDataSource\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12O\n\x12python_data_source\x18\x02 \x01(\x0b\x32\x1f.spark.connect.PythonDataSourceH\x00R\x10pythonDataSourceB\r\n\x0b\x64\x61ta_source"K\n\x10PythonDataSource\x12\x18\n\x07\x63ommand\x18\x01 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x02 \x01(\tR\tpythonVer"\x88\x01\n\x0e\x43ollectMetrics\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x33\n\x07metrics\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07metrics"\x84\x03\n\x05Parse\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x38\n\x06\x66ormat\x18\x02 \x01(\x0e\x32 .spark.connect.Parse.ParseFormatR\x06\x66ormat\x12\x34\n\x06schema\x18\x03 
\x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x06schema\x88\x01\x01\x12;\n\x07options\x18\x04 \x03(\x0b\x32!.spark.connect.Parse.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"X\n\x0bParseFormat\x12\x1c\n\x18PARSE_FORMAT_UNSPECIFIED\x10\x00\x12\x14\n\x10PARSE_FORMAT_CSV\x10\x01\x12\x15\n\x11PARSE_FORMAT_JSON\x10\x02\x42\t\n\x07_schema"\xdb\x03\n\x08\x41sOfJoin\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12\x37\n\nleft_as_of\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08leftAsOf\x12\x39\n\x0bright_as_of\x18\x04 \x01(\x0b\x32\x19.spark.connect.ExpressionR\trightAsOf\x12\x36\n\tjoin_expr\x18\x05 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08joinExpr\x12#\n\rusing_columns\x18\x06 \x03(\tR\x0cusingColumns\x12\x1b\n\tjoin_type\x18\x07 \x01(\tR\x08joinType\x12\x37\n\ttolerance\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\ttolerance\x12.\n\x13\x61llow_exact_matches\x18\t \x01(\x08R\x11\x61llowExactMatches\x12\x1c\n\tdirection\x18\n \x01(\tR\tdirectionB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -69,153 +69,155 @@ _PARSE_OPTIONSENTRY._options = None _PARSE_OPTIONSENTRY._serialized_options = b"8\001" _RELATION._serialized_start = 193 - _RELATION._serialized_end = 3626 - _UNKNOWN._serialized_start = 3628 - _UNKNOWN._serialized_end = 3637 - _RELATIONCOMMON._serialized_start = 3640 - _RELATIONCOMMON._serialized_end = 3782 - _SQL._serialized_start = 3785 - _SQL._serialized_end = 4263 - _SQL_ARGSENTRY._serialized_start = 4079 - _SQL_ARGSENTRY._serialized_end = 4169 - _SQL_NAMEDARGUMENTSENTRY._serialized_start = 4171 - _SQL_NAMEDARGUMENTSENTRY._serialized_end = 4263 - _WITHRELATIONS._serialized_start = 4265 - _WITHRELATIONS._serialized_end = 4382 - _READ._serialized_start = 4385 - _READ._serialized_end = 5048 - _READ_NAMEDTABLE._serialized_start = 4563 - _READ_NAMEDTABLE._serialized_end = 4755 - _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 4697 - _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 4755 - _READ_DATASOURCE._serialized_start = 4758 - _READ_DATASOURCE._serialized_end = 5035 - _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 4697 - _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 4755 - _PROJECT._serialized_start = 5050 - _PROJECT._serialized_end = 5167 - _FILTER._serialized_start = 5169 - _FILTER._serialized_end = 5281 - _JOIN._serialized_start = 5284 - _JOIN._serialized_end = 5945 - _JOIN_JOINDATATYPE._serialized_start = 5623 - _JOIN_JOINDATATYPE._serialized_end = 5715 - _JOIN_JOINTYPE._serialized_start = 5718 - _JOIN_JOINTYPE._serialized_end = 5926 - _SETOPERATION._serialized_start = 5948 - _SETOPERATION._serialized_end = 6427 - _SETOPERATION_SETOPTYPE._serialized_start = 6264 - _SETOPERATION_SETOPTYPE._serialized_end = 6378 - _LIMIT._serialized_start = 6429 - _LIMIT._serialized_end = 6505 - _OFFSET._serialized_start = 6507 - _OFFSET._serialized_end = 6586 - _TAIL._serialized_start = 6588 - _TAIL._serialized_end = 6663 - _AGGREGATE._serialized_start = 6666 - _AGGREGATE._serialized_end = 7432 - _AGGREGATE_PIVOT._serialized_start = 7081 - _AGGREGATE_PIVOT._serialized_end = 7192 - _AGGREGATE_GROUPINGSETS._serialized_start = 7194 - _AGGREGATE_GROUPINGSETS._serialized_end = 7270 - _AGGREGATE_GROUPTYPE._serialized_start = 7273 - _AGGREGATE_GROUPTYPE._serialized_end = 7432 - 
_SORT._serialized_start = 7435 - _SORT._serialized_end = 7595 - _DROP._serialized_start = 7598 - _DROP._serialized_end = 7739 - _DEDUPLICATE._serialized_start = 7742 - _DEDUPLICATE._serialized_end = 7982 - _LOCALRELATION._serialized_start = 7984 - _LOCALRELATION._serialized_end = 8073 - _CACHEDLOCALRELATION._serialized_start = 8075 - _CACHEDLOCALRELATION._serialized_end = 8147 - _CACHEDREMOTERELATION._serialized_start = 8149 - _CACHEDREMOTERELATION._serialized_end = 8204 - _SAMPLE._serialized_start = 8207 - _SAMPLE._serialized_end = 8480 - _RANGE._serialized_start = 8483 - _RANGE._serialized_end = 8628 - _SUBQUERYALIAS._serialized_start = 8630 - _SUBQUERYALIAS._serialized_end = 8744 - _REPARTITION._serialized_start = 8747 - _REPARTITION._serialized_end = 8889 - _SHOWSTRING._serialized_start = 8892 - _SHOWSTRING._serialized_end = 9034 - _HTMLSTRING._serialized_start = 9036 - _HTMLSTRING._serialized_end = 9150 - _STATSUMMARY._serialized_start = 9152 - _STATSUMMARY._serialized_end = 9244 - _STATDESCRIBE._serialized_start = 9246 - _STATDESCRIBE._serialized_end = 9327 - _STATCROSSTAB._serialized_start = 9329 - _STATCROSSTAB._serialized_end = 9430 - _STATCOV._serialized_start = 9432 - _STATCOV._serialized_end = 9528 - _STATCORR._serialized_start = 9531 - _STATCORR._serialized_end = 9668 - _STATAPPROXQUANTILE._serialized_start = 9671 - _STATAPPROXQUANTILE._serialized_end = 9835 - _STATFREQITEMS._serialized_start = 9837 - _STATFREQITEMS._serialized_end = 9962 - _STATSAMPLEBY._serialized_start = 9965 - _STATSAMPLEBY._serialized_end = 10274 - _STATSAMPLEBY_FRACTION._serialized_start = 10166 - _STATSAMPLEBY_FRACTION._serialized_end = 10265 - _NAFILL._serialized_start = 10277 - _NAFILL._serialized_end = 10411 - _NADROP._serialized_start = 10414 - _NADROP._serialized_end = 10548 - _NAREPLACE._serialized_start = 10551 - _NAREPLACE._serialized_end = 10847 - _NAREPLACE_REPLACEMENT._serialized_start = 10706 - _NAREPLACE_REPLACEMENT._serialized_end = 10847 - _TODF._serialized_start = 10849 - _TODF._serialized_end = 10937 - _WITHCOLUMNSRENAMED._serialized_start = 10940 - _WITHCOLUMNSRENAMED._serialized_end = 11322 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 11184 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 11251 - _WITHCOLUMNSRENAMED_RENAME._serialized_start = 11253 - _WITHCOLUMNSRENAMED_RENAME._serialized_end = 11322 - _WITHCOLUMNS._serialized_start = 11324 - _WITHCOLUMNS._serialized_end = 11443 - _WITHWATERMARK._serialized_start = 11446 - _WITHWATERMARK._serialized_end = 11580 - _HINT._serialized_start = 11583 - _HINT._serialized_end = 11715 - _UNPIVOT._serialized_start = 11718 - _UNPIVOT._serialized_end = 12045 - _UNPIVOT_VALUES._serialized_start = 11975 - _UNPIVOT_VALUES._serialized_end = 12034 - _TOSCHEMA._serialized_start = 12047 - _TOSCHEMA._serialized_end = 12153 - _REPARTITIONBYEXPRESSION._serialized_start = 12156 - _REPARTITIONBYEXPRESSION._serialized_end = 12359 - _MAPPARTITIONS._serialized_start = 12362 - _MAPPARTITIONS._serialized_end = 12594 - _GROUPMAP._serialized_start = 12597 - _GROUPMAP._serialized_end = 13232 - _COGROUPMAP._serialized_start = 13235 - _COGROUPMAP._serialized_end = 13761 - _APPLYINPANDASWITHSTATE._serialized_start = 13764 - _APPLYINPANDASWITHSTATE._serialized_end = 14121 - _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 14124 - _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 14368 - _PYTHONUDTF._serialized_start = 14371 - _PYTHONUDTF._serialized_end = 14548 - _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_start = 14551 
- _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_end = 14702 - _PYTHONDATASOURCE._serialized_start = 14704 - _PYTHONDATASOURCE._serialized_end = 14779 - _COLLECTMETRICS._serialized_start = 14782 - _COLLECTMETRICS._serialized_end = 14918 - _PARSE._serialized_start = 14921 - _PARSE._serialized_end = 15309 - _PARSE_OPTIONSENTRY._serialized_start = 4697 - _PARSE_OPTIONSENTRY._serialized_end = 4755 - _PARSE_PARSEFORMAT._serialized_start = 15210 - _PARSE_PARSEFORMAT._serialized_end = 15298 - _ASOFJOIN._serialized_start = 15312 - _ASOFJOIN._serialized_end = 15787 + _RELATION._serialized_end = 3684 + _UNKNOWN._serialized_start = 3686 + _UNKNOWN._serialized_end = 3695 + _RELATIONCOMMON._serialized_start = 3698 + _RELATIONCOMMON._serialized_end = 3840 + _SQL._serialized_start = 3843 + _SQL._serialized_end = 4321 + _SQL_ARGSENTRY._serialized_start = 4137 + _SQL_ARGSENTRY._serialized_end = 4227 + _SQL_NAMEDARGUMENTSENTRY._serialized_start = 4229 + _SQL_NAMEDARGUMENTSENTRY._serialized_end = 4321 + _WITHRELATIONS._serialized_start = 4323 + _WITHRELATIONS._serialized_end = 4440 + _READ._serialized_start = 4443 + _READ._serialized_end = 5106 + _READ_NAMEDTABLE._serialized_start = 4621 + _READ_NAMEDTABLE._serialized_end = 4813 + _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 4755 + _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 4813 + _READ_DATASOURCE._serialized_start = 4816 + _READ_DATASOURCE._serialized_end = 5093 + _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 4755 + _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 4813 + _PROJECT._serialized_start = 5108 + _PROJECT._serialized_end = 5225 + _FILTER._serialized_start = 5227 + _FILTER._serialized_end = 5339 + _JOIN._serialized_start = 5342 + _JOIN._serialized_end = 6003 + _JOIN_JOINDATATYPE._serialized_start = 5681 + _JOIN_JOINDATATYPE._serialized_end = 5773 + _JOIN_JOINTYPE._serialized_start = 5776 + _JOIN_JOINTYPE._serialized_end = 5984 + _SETOPERATION._serialized_start = 6006 + _SETOPERATION._serialized_end = 6485 + _SETOPERATION_SETOPTYPE._serialized_start = 6322 + _SETOPERATION_SETOPTYPE._serialized_end = 6436 + _LIMIT._serialized_start = 6487 + _LIMIT._serialized_end = 6563 + _OFFSET._serialized_start = 6565 + _OFFSET._serialized_end = 6644 + _TAIL._serialized_start = 6646 + _TAIL._serialized_end = 6721 + _AGGREGATE._serialized_start = 6724 + _AGGREGATE._serialized_end = 7490 + _AGGREGATE_PIVOT._serialized_start = 7139 + _AGGREGATE_PIVOT._serialized_end = 7250 + _AGGREGATE_GROUPINGSETS._serialized_start = 7252 + _AGGREGATE_GROUPINGSETS._serialized_end = 7328 + _AGGREGATE_GROUPTYPE._serialized_start = 7331 + _AGGREGATE_GROUPTYPE._serialized_end = 7490 + _SORT._serialized_start = 7493 + _SORT._serialized_end = 7653 + _DROP._serialized_start = 7656 + _DROP._serialized_end = 7797 + _DEDUPLICATE._serialized_start = 7800 + _DEDUPLICATE._serialized_end = 8040 + _LOCALRELATION._serialized_start = 8042 + _LOCALRELATION._serialized_end = 8131 + _CACHEDLOCALRELATION._serialized_start = 8133 + _CACHEDLOCALRELATION._serialized_end = 8205 + _CACHEDREMOTERELATION._serialized_start = 8207 + _CACHEDREMOTERELATION._serialized_end = 8262 + _SAMPLE._serialized_start = 8265 + _SAMPLE._serialized_end = 8538 + _RANGE._serialized_start = 8541 + _RANGE._serialized_end = 8686 + _SUBQUERYALIAS._serialized_start = 8688 + _SUBQUERYALIAS._serialized_end = 8802 + _REPARTITION._serialized_start = 8805 + _REPARTITION._serialized_end = 8947 + _SHOWSTRING._serialized_start = 8950 + _SHOWSTRING._serialized_end = 9092 + _HTMLSTRING._serialized_start = 9094 + 
_HTMLSTRING._serialized_end = 9208 + _STATSUMMARY._serialized_start = 9210 + _STATSUMMARY._serialized_end = 9302 + _STATDESCRIBE._serialized_start = 9304 + _STATDESCRIBE._serialized_end = 9385 + _STATCROSSTAB._serialized_start = 9387 + _STATCROSSTAB._serialized_end = 9488 + _STATCOV._serialized_start = 9490 + _STATCOV._serialized_end = 9586 + _STATCORR._serialized_start = 9589 + _STATCORR._serialized_end = 9726 + _STATAPPROXQUANTILE._serialized_start = 9729 + _STATAPPROXQUANTILE._serialized_end = 9893 + _STATFREQITEMS._serialized_start = 9895 + _STATFREQITEMS._serialized_end = 10020 + _STATSAMPLEBY._serialized_start = 10023 + _STATSAMPLEBY._serialized_end = 10332 + _STATSAMPLEBY_FRACTION._serialized_start = 10224 + _STATSAMPLEBY_FRACTION._serialized_end = 10323 + _NAFILL._serialized_start = 10335 + _NAFILL._serialized_end = 10469 + _NADROP._serialized_start = 10472 + _NADROP._serialized_end = 10606 + _NAREPLACE._serialized_start = 10609 + _NAREPLACE._serialized_end = 10905 + _NAREPLACE_REPLACEMENT._serialized_start = 10764 + _NAREPLACE_REPLACEMENT._serialized_end = 10905 + _TODF._serialized_start = 10907 + _TODF._serialized_end = 10995 + _WITHCOLUMNSRENAMED._serialized_start = 10998 + _WITHCOLUMNSRENAMED._serialized_end = 11380 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 11242 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 11309 + _WITHCOLUMNSRENAMED_RENAME._serialized_start = 11311 + _WITHCOLUMNSRENAMED_RENAME._serialized_end = 11380 + _WITHCOLUMNS._serialized_start = 11382 + _WITHCOLUMNS._serialized_end = 11501 + _WITHWATERMARK._serialized_start = 11504 + _WITHWATERMARK._serialized_end = 11638 + _HINT._serialized_start = 11641 + _HINT._serialized_end = 11773 + _UNPIVOT._serialized_start = 11776 + _UNPIVOT._serialized_end = 12103 + _UNPIVOT_VALUES._serialized_start = 12033 + _UNPIVOT_VALUES._serialized_end = 12092 + _TRANSPOSE._serialized_start = 12105 + _TRANSPOSE._serialized_end = 12227 + _TOSCHEMA._serialized_start = 12229 + _TOSCHEMA._serialized_end = 12335 + _REPARTITIONBYEXPRESSION._serialized_start = 12338 + _REPARTITIONBYEXPRESSION._serialized_end = 12541 + _MAPPARTITIONS._serialized_start = 12544 + _MAPPARTITIONS._serialized_end = 12776 + _GROUPMAP._serialized_start = 12779 + _GROUPMAP._serialized_end = 13414 + _COGROUPMAP._serialized_start = 13417 + _COGROUPMAP._serialized_end = 13943 + _APPLYINPANDASWITHSTATE._serialized_start = 13946 + _APPLYINPANDASWITHSTATE._serialized_end = 14303 + _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 14306 + _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 14550 + _PYTHONUDTF._serialized_start = 14553 + _PYTHONUDTF._serialized_end = 14730 + _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_start = 14733 + _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_end = 14884 + _PYTHONDATASOURCE._serialized_start = 14886 + _PYTHONDATASOURCE._serialized_end = 14961 + _COLLECTMETRICS._serialized_start = 14964 + _COLLECTMETRICS._serialized_end = 15100 + _PARSE._serialized_start = 15103 + _PARSE._serialized_end = 15491 + _PARSE_OPTIONSENTRY._serialized_start = 4755 + _PARSE_OPTIONSENTRY._serialized_end = 4813 + _PARSE_PARSEFORMAT._serialized_start = 15392 + _PARSE_PARSEFORMAT._serialized_end = 15480 + _ASOFJOIN._serialized_start = 15494 + _ASOFJOIN._serialized_end = 15969 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 864803fd33084..b1cd2e184d085 100644 --- 
a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -104,6 +104,7 @@ class Relation(google.protobuf.message.Message): AS_OF_JOIN_FIELD_NUMBER: builtins.int COMMON_INLINE_USER_DEFINED_DATA_SOURCE_FIELD_NUMBER: builtins.int WITH_RELATIONS_FIELD_NUMBER: builtins.int + TRANSPOSE_FIELD_NUMBER: builtins.int FILL_NA_FIELD_NUMBER: builtins.int DROP_NA_FIELD_NUMBER: builtins.int REPLACE_FIELD_NUMBER: builtins.int @@ -205,6 +206,8 @@ class Relation(google.protobuf.message.Message): @property def with_relations(self) -> global___WithRelations: ... @property + def transpose(self) -> global___Transpose: ... + @property def fill_na(self) -> global___NAFill: """NA functions""" @property @@ -284,6 +287,7 @@ class Relation(google.protobuf.message.Message): common_inline_user_defined_data_source: global___CommonInlineUserDefinedDataSource | None = ..., with_relations: global___WithRelations | None = ..., + transpose: global___Transpose | None = ..., fill_na: global___NAFill | None = ..., drop_na: global___NADrop | None = ..., replace: global___NAReplace | None = ..., @@ -402,6 +406,8 @@ class Relation(google.protobuf.message.Message): b"to_df", "to_schema", b"to_schema", + "transpose", + b"transpose", "unknown", b"unknown", "unpivot", @@ -519,6 +525,8 @@ class Relation(google.protobuf.message.Message): b"to_df", "to_schema", b"to_schema", + "transpose", + b"transpose", "unknown", b"unknown", "unpivot", @@ -577,6 +585,7 @@ class Relation(google.protobuf.message.Message): "as_of_join", "common_inline_user_defined_data_source", "with_relations", + "transpose", "fill_na", "drop_na", "replace", @@ -3141,6 +3150,47 @@ class Unpivot(google.protobuf.message.Message): global___Unpivot = Unpivot +class Transpose(google.protobuf.message.Message): + """Transpose a DataFrame, switching rows to columns. + Transforms the DataFrame such that the values in the specified index column + become the new columns of the DataFrame. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + INPUT_FIELD_NUMBER: builtins.int + INDEX_COLUMNS_FIELD_NUMBER: builtins.int + @property + def input(self) -> global___Relation: + """(Required) The input relation.""" + @property + def index_columns( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + pyspark.sql.connect.proto.expressions_pb2.Expression + ]: + """(Optional) A list of columns that will be treated as the indices. + Only single column is supported now. + """ + def __init__( + self, + *, + input: global___Relation | None = ..., + index_columns: collections.abc.Iterable[ + pyspark.sql.connect.proto.expressions_pb2.Expression + ] + | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["input", b"input"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal["index_columns", b"index_columns", "input", b"input"], + ) -> None: ... 
+ +global___Transpose = Transpose + class ToSchema(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index e5246e893f658..cacb479229bb7 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -19,6 +19,7 @@ check_dependencies(__name__) +import json import threading import os import warnings @@ -200,6 +201,26 @@ def enableHiveSupport(self) -> "SparkSession.Builder": ) def _apply_options(self, session: "SparkSession") -> None: + init_opts = {} + for i in range(int(os.environ.get("PYSPARK_REMOTE_INIT_CONF_LEN", "0"))): + init_opts = json.loads(os.environ[f"PYSPARK_REMOTE_INIT_CONF_{i}"]) + + with self._lock: + for k, v in init_opts.items(): + # the options are applied after session creation, + # so following options always take no effect + if k not in [ + "spark.remote", + "spark.master", + ] and k.startswith("spark.sql."): + # Only attempts to set Spark SQL configurations. + # If the configurations are static, it might throw an exception so + # simply ignore it for now. + try: + session.conf.set(k, v) + except Exception: + pass + with self._lock: for k, v in self._options.items(): # the options are applied after session creation, @@ -993,10 +1014,17 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None: session = PySparkSession._instantiatedSession if session is None or session._sc._jsc is None: + init_opts = {} + for i in range(int(os.environ.get("PYSPARK_REMOTE_INIT_CONF_LEN", "0"))): + init_opts = json.loads(os.environ[f"PYSPARK_REMOTE_INIT_CONF_{i}"]) + init_opts.update(opts) + opts = init_opts + # Configurations to be overwritten overwrite_conf = opts overwrite_conf["spark.master"] = master overwrite_conf["spark.local.connect"] = "1" + os.environ["SPARK_LOCAL_CONNECT"] = "1" # Configurations to be set if unset. default_conf = {"spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin"} @@ -1030,6 +1058,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None: finally: if origin_remote is not None: os.environ["SPARK_REMOTE"] = origin_remote + del os.environ["SPARK_LOCAL_CONNECT"] else: raise PySparkRuntimeError( errorClass="SESSION_OR_CONTEXT_EXISTS", diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 7d3900c7afbc5..2179a844b1e5e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -43,6 +43,7 @@ from pyspark.sql.types import StructType, Row from pyspark.sql.utils import dispatch_df_method + if TYPE_CHECKING: from py4j.java_gateway import JavaObject import pyarrow as pa @@ -65,6 +66,7 @@ ArrowMapIterFunction, DataFrameLike as PandasDataFrameLike, ) + from pyspark.sql.plot import PySparkPlotAccessor from pyspark.sql.metrics import ExecutionInfo @@ -1332,7 +1334,7 @@ def offset(self, num: int) -> "DataFrame": .. versionadded:: 3.4.0 .. versionchanged:: 3.5.0 - Supports vanilla PySpark. + Supports classic PySpark. Parameters ---------- @@ -6168,10 +6170,6 @@ def mapInPandas( | 1| 21| +---+---+ - Notes - ----- - This API is experimental - See Also -------- pyspark.sql.functions.pandas_udf @@ -6245,10 +6243,6 @@ def mapInArrow( | 1| 21| +---+---+ - Notes - ----- - This API is unstable, and for developers. - See Also -------- pyspark.sql.functions.pandas_udf @@ -6311,6 +6305,72 @@ def toPandas(self) -> "PandasDataFrameLike": """ ... 
+ @dispatch_df_method + def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> "DataFrame": + """ + Transposes a DataFrame such that the values in the specified index column become the new + columns of the DataFrame. If no index column is provided, the first column is used as + the default. + + Please note: + - All columns except the index column must share a least common data type. Unless they + are the same data type, all columns are cast to the nearest common data type. + - The name of the column into which the original column names are transposed defaults + to "key". + - null values in the index column are excluded from the column names for the + transposed table, which are ordered in ascending order. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + indexColumn : str or :class:`Column`, optional + The single column that will be treated as the index for the transpose operation. This + column will be used to transform the DataFrame such that the values of the indexColumn + become the new columns in the transposed DataFrame. If not provided, the first column of + the DataFrame will be used as the default. + + Returns + ------- + :class:`DataFrame` + Transposed DataFrame. + + Notes + ----- + Supports Spark Connect. + + Examples + -------- + >>> df = spark.createDataFrame( + ... [("A", 1, 2), ("B", 3, 4)], + ... ["id", "val1", "val2"], + ... ) + >>> df.show() + +---+----+----+ + | id|val1|val2| + +---+----+----+ + | A| 1| 2| + | B| 3| 4| + +---+----+----+ + + >>> df.transpose().show() + +----+---+---+ + | key| A| B| + +----+---+---+ + |val1| 1| 3| + |val2| 2| 4| + +----+---+---+ + + >>> df.transpose(df.id).show() + +----+---+---+ + | key| A| B| + +----+---+---+ + |val1| 1| 3| + |val2| 2| 4| + +----+---+---+ + """ + ... + @property def executionInfo(self) -> Optional["ExecutionInfo"]: """ @@ -6336,6 +6396,32 @@ def executionInfo(self) -> Optional["ExecutionInfo"]: """ ... + @property + def plot(self) -> "PySparkPlotAccessor": + """ + Returns a :class:`PySparkPlotAccessor` for plotting functions. + + .. versionadded:: 4.0.0 + + Returns + ------- + :class:`PySparkPlotAccessor` + + Notes + ----- + This API is experimental. + + Examples + -------- + >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + >>> columns = ["category", "int_val", "float_val"] + >>> df = spark.createDataFrame(data, columns) + >>> type(df.plot) + + >>> df.plot.line(x="category", y=["int_val", "float_val"]) # doctest: +SKIP + """ + ... + class DataFrameNaFunctions: """Functionality for working with missing data in :class:`DataFrame`. diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py index 72d42ae5e0cdf..a51c96a9d178f 100644 --- a/python/pyspark/sql/datasource.py +++ b/python/pyspark/sql/datasource.py @@ -23,9 +23,9 @@ from pyspark.errors import PySparkNotImplementedError if TYPE_CHECKING: + from pyarrow import RecordBatch from pyspark.sql.session import SparkSession - __all__ = [ "DataSource", "DataSourceReader", @@ -333,7 +333,7 @@ def partitions(self) -> Sequence[InputPartition]: ) @abstractmethod - def read(self, partition: InputPartition) -> Iterator[Tuple]: + def read(self, partition: InputPartition) -> Union[Iterator[Tuple], Iterator["RecordBatch"]]: """ Generates data for a given partition and returns an iterator of tuples or rows. @@ -350,9 +350,11 @@ def read(self, partition: InputPartition) -> Iterator[Tuple]: Returns ------- - iterator of tuples or :class:`Row`\\s + iterator of tuples or PyArrow's `RecordBatch` An iterator of tuples or rows. 
Each tuple or row will be converted to a row in the final DataFrame. + It can also return an iterator of PyArrow's `RecordBatch` if the data source + supports it. Examples -------- @@ -448,7 +450,7 @@ def partitions(self, start: dict, end: dict) -> Sequence[InputPartition]: ) @abstractmethod - def read(self, partition: InputPartition) -> Iterator[Tuple]: + def read(self, partition: InputPartition) -> Union[Iterator[Tuple], Iterator["RecordBatch"]]: """ Generates data for a given partition and returns an iterator of tuples or rows. @@ -470,9 +472,11 @@ def read(self, partition: InputPartition) -> Iterator[Tuple]: Returns ------- - iterator of tuples or :class:`Row`\\s + iterator of tuples or PyArrow's `RecordBatch` An iterator of tuples or rows. Each tuple or row will be converted to a row in the final DataFrame. + It can also return an iterator of PyArrow's `RecordBatch` if the data source + supports it. """ raise PySparkNotImplementedError( errorClass="NOT_IMPLEMENTED", diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 344ba8d009ac4..5f8d1c21a24f1 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -6015,9 +6015,9 @@ def grouping_id(*cols: "ColumnOrName") -> Column: @_try_remote_functions def count_min_sketch( col: "ColumnOrName", - eps: "ColumnOrName", - confidence: "ColumnOrName", - seed: "ColumnOrName", + eps: Union[Column, float], + confidence: Union[Column, float], + seed: Optional[Union[Column, int]] = None, ) -> Column: """ Returns a count-min sketch of a column with the given esp, confidence and seed. @@ -6031,13 +6031,24 @@ def count_min_sketch( ---------- col : :class:`~pyspark.sql.Column` or str target column to compute on. - eps : :class:`~pyspark.sql.Column` or str + eps : :class:`~pyspark.sql.Column` or float relative error, must be positive - confidence : :class:`~pyspark.sql.Column` or str + + .. versionchanged:: 4.0.0 + `eps` now accepts float value. + + confidence : :class:`~pyspark.sql.Column` or float confidence, must be positive and less than 1.0 - seed : :class:`~pyspark.sql.Column` or str + + .. versionchanged:: 4.0.0 + `confidence` now accepts float value. + + seed : :class:`~pyspark.sql.Column` or int, optional random seed + .. versionchanged:: 4.0.0 + `seed` now accepts int value. + Returns ------- :class:`~pyspark.sql.Column` @@ -6045,12 +6056,48 @@ def count_min_sketch( Examples -------- - >>> df = spark.createDataFrame([[1], [2], [1]], ['data']) - >>> df = df.agg(count_min_sketch(df.data, lit(0.5), lit(0.5), lit(1)).alias('sketch')) - >>> df.select(hex(df.sketch).alias('r')).collect() - [Row(r='0000000100000000000000030000000100000004000000005D8D6AB90000000000000000000000000000000200000000000000010000000000000000')] - """ - return _invoke_function_over_columns("count_min_sketch", col, eps, confidence, seed) + Example 1: Using columns as arguments + + >>> from pyspark.sql import functions as sf + >>> spark.range(100).select( + ... sf.hex(sf.count_min_sketch(sf.col("id"), sf.lit(3.0), sf.lit(0.1), sf.lit(1))) + ... 
).show(truncate=False) + +------------------------------------------------------------------------+ + |hex(count_min_sketch(id, 3.0, 0.1, 1)) | + +------------------------------------------------------------------------+ + |0000000100000000000000640000000100000001000000005D8D6AB90000000000000064| + +------------------------------------------------------------------------+ + + Example 2: Using numbers as arguments + + >>> from pyspark.sql import functions as sf + >>> spark.range(100).select( + ... sf.hex(sf.count_min_sketch("id", 1.0, 0.3, 2)) + ... ).show(truncate=False) + +----------------------------------------------------------------------------------------+ + |hex(count_min_sketch(id, 1.0, 0.3, 2)) | + +----------------------------------------------------------------------------------------+ + |0000000100000000000000640000000100000002000000005D96391C00000000000000320000000000000032| + +----------------------------------------------------------------------------------------+ + + Example 3: Using a random seed + + >>> from pyspark.sql import functions as sf + >>> spark.range(100).select( + ... sf.hex(sf.count_min_sketch("id", sf.lit(1.5), 0.6)) + ... ).show(truncate=False) # doctest: +SKIP + +----------------------------------------------------------------------------------------------------------------------------------------+ + |hex(count_min_sketch(id, 1.5, 0.6, 2120704260)) | + +----------------------------------------------------------------------------------------------------------------------------------------+ + |0000000100000000000000640000000200000002000000005ADECCEE00000000153EBE090000000000000033000000000000003100000000000000320000000000000032| + +----------------------------------------------------------------------------------------------------------------------------------------+ + """ # noqa: E501 + _eps = lit(eps) + _conf = lit(confidence) + if seed is None: + return _invoke_function_over_columns("count_min_sketch", col, _eps, _conf) + else: + return _invoke_function_over_columns("count_min_sketch", col, _eps, _conf, lit(seed)) @_try_remote_functions @@ -7305,36 +7352,36 @@ def lag(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> | b| 2| +---+---+ >>> w = Window.partitionBy("c1").orderBy("c2") - >>> df.withColumn("previos_value", lag("c2").over(w)).show() - +---+---+-------------+ - | c1| c2|previos_value| - +---+---+-------------+ - | a| 1| NULL| - | a| 2| 1| - | a| 3| 2| - | b| 2| NULL| - | b| 8| 2| - +---+---+-------------+ - >>> df.withColumn("previos_value", lag("c2", 1, 0).over(w)).show() - +---+---+-------------+ - | c1| c2|previos_value| - +---+---+-------------+ - | a| 1| 0| - | a| 2| 1| - | a| 3| 2| - | b| 2| 0| - | b| 8| 2| - +---+---+-------------+ - >>> df.withColumn("previos_value", lag("c2", 2, -1).over(w)).show() - +---+---+-------------+ - | c1| c2|previos_value| - +---+---+-------------+ - | a| 1| -1| - | a| 2| -1| - | a| 3| 1| - | b| 2| -1| - | b| 8| -1| - +---+---+-------------+ + >>> df.withColumn("previous_value", lag("c2").over(w)).show() + +---+---+--------------+ + | c1| c2|previous_value| + +---+---+--------------+ + | a| 1| NULL| + | a| 2| 1| + | a| 3| 2| + | b| 2| NULL| + | b| 8| 2| + +---+---+--------------+ + >>> df.withColumn("previous_value", lag("c2", 1, 0).over(w)).show() + +---+---+--------------+ + | c1| c2|previous_value| + +---+---+--------------+ + | a| 1| 0| + | a| 2| 1| + | a| 3| 2| + | b| 2| 0| + | b| 8| 2| + +---+---+--------------+ + >>> df.withColumn("previous_value", lag("c2", 2, -1).over(w)).show() + 
+---+---+--------------+ + | c1| c2|previous_value| + +---+---+--------------+ + | a| 1| -1| + | a| 2| -1| + | a| 3| 1| + | b| 2| -1| + | b| 8| -1| + +---+---+--------------+ """ from pyspark.sql.classic.column import _to_java_column @@ -11241,13 +11288,27 @@ def sentences( ) -> Column: """ Splits a string into arrays of sentences, where each sentence is an array of words. - The 'language' and 'country' arguments are optional, and if omitted, the default locale is used. + The `language` and `country` arguments are optional, + When they are omitted: + 1.If they are both omitted, the `Locale.ROOT - locale(language='', country='')` is used. + The `Locale.ROOT` is regarded as the base locale of all locales, and is used as the + language/country neutral locale for the locale sensitive operations. + 2.If the `country` is omitted, the `locale(language, country='')` is used. + When they are null: + 1.If they are both `null`, the `Locale.US - locale(language='en', country='US')` is used. + 2.If the `language` is null and the `country` is not null, + the `Locale.US - locale(language='en', country='US')` is used. + 3.If the `language` is not null and the `country` is null, the `locale(language)` is used. + 4.If neither is `null`, the `locale(language, country)` is used. .. versionadded:: 3.2.0 .. versionchanged:: 3.4.0 Supports Spark Connect. + .. versionchanged:: 4.0.0 + Supports `sentences(string, language)`. + Parameters ---------- string : :class:`~pyspark.sql.Column` or str @@ -11271,6 +11332,12 @@ def sentences( +-----------------------------------+ |[[This, is, an, example, sentence]]| +-----------------------------------+ + >>> df.select(sentences(df.string, lit("en"))).show(truncate=False) + +-----------------------------------+ + |sentences(string, en, ) | + +-----------------------------------+ + |[[This, is, an, example, sentence]]| + +-----------------------------------+ >>> df = spark.createDataFrame([["Hello world. 
How are you?"]], ["s"]) >>> df.select(sentences("s")).show(truncate=False) +---------------------------------+ @@ -11289,7 +11356,9 @@ def sentences( @_try_remote_functions def substring( - str: "ColumnOrName", pos: Union["ColumnOrName", int], len: Union["ColumnOrName", int] + str: "ColumnOrName", + pos: Union["ColumnOrName", int], + len: Union["ColumnOrName", int], ) -> Column: """ Substring starts at `pos` and is of length `len` when str is String type or @@ -11328,16 +11397,59 @@ def substring( Examples -------- + Example 1: Using literal integers as arguments + + >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(substring(df.s, 1, 2).alias('s')).collect() - [Row(s='ab')] + >>> df.select('*', sf.substring(df.s, 1, 2)).show() + +----+------------------+ + | s|substring(s, 1, 2)| + +----+------------------+ + |abcd| ab| + +----+------------------+ + + Example 2: Using columns as arguments + + >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l']) - >>> df.select(substring(df.s, 2, df.l).alias('s')).collect() - [Row(s='par')] - >>> df.select(substring(df.s, df.p, 3).alias('s')).collect() - [Row(s='par')] - >>> df.select(substring(df.s, df.p, df.l).alias('s')).collect() - [Row(s='par')] + >>> df.select('*', sf.substring(df.s, 2, df.l)).show() + +-----+---+---+------------------+ + | s| p| l|substring(s, 2, l)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ + + >>> df.select('*', sf.substring(df.s, df.p, 3)).show() + +-----+---+---+------------------+ + | s| p| l|substring(s, p, 3)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ + + >>> df.select('*', sf.substring(df.s, df.p, df.l)).show() + +-----+---+---+------------------+ + | s| p| l|substring(s, p, l)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ + + Example 3: Using column names as arguments + + >>> import pyspark.sql.functions as sf + >>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l']) + >>> df.select('*', sf.substring(df.s, 2, 'l')).show() + +-----+---+---+------------------+ + | s| p| l|substring(s, 2, l)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ + + >>> df.select('*', sf.substring('s', 'p', 'l')).show() + +-----+---+---+------------------+ + | s| p| l|substring(s, p, l)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ """ pos = _enum_to_value(pos) pos = lit(pos) if isinstance(pos, int) else pos @@ -16308,6 +16420,55 @@ def try_parse_json( return _invoke_function("try_parse_json", _to_java_column(col)) +@_try_remote_functions +def to_variant_object( + col: "ColumnOrName", +) -> Column: + """ + Converts a column containing nested inputs (array/map/struct) into a variants where maps and + structs are converted to variant objects which are unordered unlike SQL structs. Input maps can + only have string keys. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + a column with a nested schema or column name + + Returns + ------- + :class:`~pyspark.sql.Column` + a new column of VariantType. 
+ + Examples + -------- + Example 1: Converting an array containing a nested struct into a variant + + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql.types import ArrayType, StructType, StructField, StringType, MapType + >>> schema = StructType([ + ... StructField("i", StringType(), True), + ... StructField("v", ArrayType(StructType([ + ... StructField("a", MapType(StringType(), StringType()), True) + ... ]), True)) + ... ]) + >>> data = [("1", [{"a": {"b": 2}}])] + >>> df = spark.createDataFrame(data, schema) + >>> df.select(sf.to_variant_object(df.v)) + DataFrame[to_variant_object(v): variant] + >>> df.select(sf.to_variant_object(df.v)).show(truncate=False) + +--------------------+ + |to_variant_object(v)| + +--------------------+ + |[{"a":{"b":"2"}}] | + +--------------------+ + """ + from pyspark.sql.classic.column import _to_java_column + + return _invoke_function("to_variant_object", _to_java_column(col)) + + @_try_remote_functions def parse_json( col: "ColumnOrName", @@ -16467,7 +16628,7 @@ def schema_of_variant(v: "ColumnOrName") -> Column: -------- >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ]) >>> df.select(schema_of_variant(parse_json(df.json)).alias("r")).collect() - [Row(r='STRUCT')] + [Row(r='OBJECT')] """ from pyspark.sql.classic.column import _to_java_column @@ -16495,7 +16656,7 @@ def schema_of_variant_agg(v: "ColumnOrName") -> Column: -------- >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ]) >>> df.select(schema_of_variant_agg(parse_json(df.json)).alias("r")).collect() - [Row(r='STRUCT')] + [Row(r='OBJECT')] """ from pyspark.sql.classic.column import _to_java_column diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 06834553ea96a..3173534c03c91 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -224,8 +224,6 @@ def applyInPandas( into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory. - This API is experimental. - See Also -------- pyspark.sql.functions.pandas_udf @@ -329,8 +327,6 @@ def applyInPandasWithState( Notes ----- This function requires a full shuffle. - - This API is experimental. """ from pyspark.sql import GroupedData @@ -484,8 +480,6 @@ def transformWithStateInPandas( Notes ----- This function requires a full shuffle. - - This API is experimental. """ from pyspark.sql import GroupedData @@ -683,10 +677,6 @@ class PandasCogroupedOps: .. versionchanged:: 3.4.0 Support Spark Connect. - - Notes - ----- - This API is experimental. """ def __init__(self, gd1: "GroupedData", gd2: "GroupedData"): @@ -778,8 +768,6 @@ def applyInPandas( into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory. - This API is experimental. - See Also -------- pyspark.sql.functions.pandas_udf diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 6203d4d19d866..076226865f3a7 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -510,8 +510,8 @@ def _create_batch(self, series): # If it returns a pd.Series, it should throw an error. if not isinstance(s, pd.DataFrame): raise PySparkValueError( - "A field of type StructType expects a pandas.DataFrame, " - "but got: %s" % str(type(s)) + "Invalid return type. 
Please make sure that the UDF returns a " + "pandas.DataFrame when the specified return type is StructType." ) arrs.append(self._create_struct_array(s, t)) else: diff --git a/python/pyspark/sql/plot/__init__.py b/python/pyspark/sql/plot/__init__.py new file mode 100644 index 0000000000000..6da07061b2a09 --- /dev/null +++ b/python/pyspark/sql/plot/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This package includes the plotting APIs for PySpark DataFrame. +""" +from pyspark.sql.plot.core import * # noqa: F403, F401 diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py new file mode 100644 index 0000000000000..392ef73b38845 --- /dev/null +++ b/python/pyspark/sql/plot/core.py @@ -0,0 +1,135 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Any, TYPE_CHECKING, Optional, Union +from types import ModuleType +from pyspark.errors import PySparkRuntimeError, PySparkValueError +from pyspark.sql.utils import require_minimum_plotly_version + + +if TYPE_CHECKING: + from pyspark.sql import DataFrame + import pandas as pd + from plotly.graph_objs import Figure + + +class PySparkTopNPlotBase: + def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame": + from pyspark.sql import SparkSession + + session = SparkSession.getActiveSession() + if session is None: + raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict()) + + max_rows = int( + session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type] + ) + pdf = sdf.limit(max_rows + 1).toPandas() + + self.partial = False + if len(pdf) > max_rows: + self.partial = True + pdf = pdf.iloc[:max_rows] + + return pdf + + +class PySparkSampledPlotBase: + def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame": + from pyspark.sql import SparkSession + + session = SparkSession.getActiveSession() + if session is None: + raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict()) + + sample_ratio = session.conf.get("spark.sql.pyspark.plotting.sample_ratio") + max_rows = int( + session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type] + ) + + if sample_ratio is None: + fraction = 1 / (sdf.count() / max_rows) + fraction = min(1.0, fraction) + else: + fraction = float(sample_ratio) + + sampled_sdf = sdf.sample(fraction=fraction) + pdf = sampled_sdf.toPandas() + + return pdf + + +class PySparkPlotAccessor: + plot_data_map = { + "line": PySparkSampledPlotBase().get_sampled, + } + _backends = {} # type: ignore[var-annotated] + + def __init__(self, data: "DataFrame"): + self.data = data + + def __call__( + self, kind: str = "line", backend: Optional[str] = None, **kwargs: Any + ) -> "Figure": + plot_backend = PySparkPlotAccessor._get_plot_backend(backend) + + return plot_backend.plot_pyspark(self.data, kind=kind, **kwargs) + + @staticmethod + def _get_plot_backend(backend: Optional[str] = None) -> ModuleType: + backend = backend or "plotly" + + if backend in PySparkPlotAccessor._backends: + return PySparkPlotAccessor._backends[backend] + + if backend == "plotly": + require_minimum_plotly_version() + else: + raise PySparkValueError( + errorClass="UNSUPPORTED_PLOT_BACKEND", + messageParameters={"backend": backend, "supported_backends": ", ".join(["plotly"])}, + ) + from pyspark.sql.plot import plotly as module + + return module + + def line(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure": + """ + Plot DataFrame as lines. + + Parameters + ---------- + x : str + Name of column to use for the horizontal axis. + y : str or list of str + Name(s) of the column(s) to use for the vertical axis. Multiple columns can be plotted. + **kwargs : optional + Additional keyword arguments. 
+ + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + >>> columns = ["category", "int_val", "float_val"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.line(x="category", y="int_val") # doctest: +SKIP + >>> df.plot.line(x="category", y=["int_val", "float_val"]) # doctest: +SKIP + """ + return self(kind="line", x=x, y=y, **kwargs) diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py new file mode 100644 index 0000000000000..5efc19476057f --- /dev/null +++ b/python/pyspark/sql/plot/plotly.py @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import TYPE_CHECKING, Any + +from pyspark.sql.plot import PySparkPlotAccessor + +if TYPE_CHECKING: + from pyspark.sql import DataFrame + from plotly.graph_objs import Figure + + +def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure": + import plotly + + return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs) diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.py b/python/pyspark/sql/streaming/StateMessage_pb2.py index 0f096e16d47ad..a22f004fd3048 100644 --- a/python/pyspark/sql/streaming/StateMessage_pb2.py +++ b/python/pyspark/sql/streaming/StateMessage_pb2.py @@ -16,14 +16,12 @@ # # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
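Putting the new pyspark.sql.plot pieces above together, a DataFrame gains a plotly-backed plot accessor. A minimal usage sketch, assuming plotly is installed, an active SparkSession named spark, and the DataFrame.plot property wired up elsewhere in this change:

data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
df = spark.createDataFrame(data, ["category", "int_val", "float_val"])

# Line plots go through PySparkSampledPlotBase, so rows are down-sampled
# according to spark.sql.pyspark.plotting.sample_ratio / .max_rows before
# being converted to pandas and handed to plotly.
fig = df.plot.line(x="category", y=["int_val", "float_val"])
fig.show()  # fig is a plotly.graph_objs.Figure; show() is only illustrative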
-# NO CHECKED-IN PROTOBUF GENCODE # source: StateMessage.proto -# Protobuf Python Version: 5.27.1 """Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) @@ -31,16 +29,17 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"\x89\x03\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"z\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"5\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState*K\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\n\n\x06\x43LOSED\x10\x03\x62\x06proto3' # noqa: E501 + b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 
\x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"\x89\x03\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"z\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"}\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t\x12\x46\n\x03ttl\x18\x03 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfig"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState"\x1f\n\tTTLConfig\x12\x12\n\ndurationMs\x18\x01 \x01(\x05*K\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\n\n\x06\x43LOSED\x10\x03\x62\x06proto3' # noqa: E501 ) _globals = globals() + _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "StateMessage_pb2", _globals) if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals["_HANDLESTATE"]._serialized_start = 1873 - _globals["_HANDLESTATE"]._serialized_end = 1948 + DESCRIPTOR._options = None + _globals["_HANDLESTATE"]._serialized_start = 1978 + _globals["_HANDLESTATE"]._serialized_end = 2053 _globals["_STATEREQUEST"]._serialized_start = 71 _globals["_STATEREQUEST"]._serialized_end = 432 _globals["_STATERESPONSE"]._serialized_start = 434 @@ -52,21 +51,23 @@ _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1029 _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1253 
_globals["_STATECALLCOMMAND"]._serialized_start = 1255 - _globals["_STATECALLCOMMAND"]._serialized_end = 1308 - _globals["_VALUESTATECALL"]._serialized_start = 1311 - _globals["_VALUESTATECALL"]._serialized_end = 1664 - _globals["_SETIMPLICITKEY"]._serialized_start = 1666 - _globals["_SETIMPLICITKEY"]._serialized_end = 1695 - _globals["_REMOVEIMPLICITKEY"]._serialized_start = 1697 - _globals["_REMOVEIMPLICITKEY"]._serialized_end = 1716 - _globals["_EXISTS"]._serialized_start = 1718 - _globals["_EXISTS"]._serialized_end = 1726 - _globals["_GET"]._serialized_start = 1728 - _globals["_GET"]._serialized_end = 1733 - _globals["_VALUESTATEUPDATE"]._serialized_start = 1735 - _globals["_VALUESTATEUPDATE"]._serialized_end = 1768 - _globals["_CLEAR"]._serialized_start = 1770 - _globals["_CLEAR"]._serialized_end = 1777 - _globals["_SETHANDLESTATE"]._serialized_start = 1779 - _globals["_SETHANDLESTATE"]._serialized_end = 1871 + _globals["_STATECALLCOMMAND"]._serialized_end = 1380 + _globals["_VALUESTATECALL"]._serialized_start = 1383 + _globals["_VALUESTATECALL"]._serialized_end = 1736 + _globals["_SETIMPLICITKEY"]._serialized_start = 1738 + _globals["_SETIMPLICITKEY"]._serialized_end = 1767 + _globals["_REMOVEIMPLICITKEY"]._serialized_start = 1769 + _globals["_REMOVEIMPLICITKEY"]._serialized_end = 1788 + _globals["_EXISTS"]._serialized_start = 1790 + _globals["_EXISTS"]._serialized_end = 1798 + _globals["_GET"]._serialized_start = 1800 + _globals["_GET"]._serialized_end = 1805 + _globals["_VALUESTATEUPDATE"]._serialized_start = 1807 + _globals["_VALUESTATEUPDATE"]._serialized_end = 1840 + _globals["_CLEAR"]._serialized_start = 1842 + _globals["_CLEAR"]._serialized_end = 1849 + _globals["_SETHANDLESTATE"]._serialized_start = 1851 + _globals["_SETHANDLESTATE"]._serialized_end = 1943 + _globals["_TTLCONFIG"]._serialized_start = 1945 + _globals["_TTLCONFIG"]._serialized_end = 1976 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.pyi b/python/pyspark/sql/streaming/StateMessage_pb2.pyi index 0e6f1fb065881..1ab48a27c8f87 100644 --- a/python/pyspark/sql/streaming/StateMessage_pb2.pyi +++ b/python/pyspark/sql/streaming/StateMessage_pb2.pyi @@ -13,163 +13,167 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +# from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message -from typing import ( - ClassVar as _ClassVar, - Mapping as _Mapping, - Optional as _Optional, - Union as _Union, -) +from typing import ClassVar, Mapping, Optional, Union +CLOSED: HandleState +CREATED: HandleState +DATA_PROCESSED: HandleState DESCRIPTOR: _descriptor.FileDescriptor +INITIALIZED: HandleState -class HandleState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): +class Clear(_message.Message): __slots__ = () - CREATED: _ClassVar[HandleState] - INITIALIZED: _ClassVar[HandleState] - DATA_PROCESSED: _ClassVar[HandleState] - CLOSED: _ClassVar[HandleState] + def __init__(self) -> None: ... -CREATED: HandleState -INITIALIZED: HandleState -DATA_PROCESSED: HandleState -CLOSED: HandleState +class Exists(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class Get(_message.Message): + __slots__ = () + def __init__(self) -> None: ... 
+ +class ImplicitGroupingKeyRequest(_message.Message): + __slots__ = ["removeImplicitKey", "setImplicitKey"] + REMOVEIMPLICITKEY_FIELD_NUMBER: ClassVar[int] + SETIMPLICITKEY_FIELD_NUMBER: ClassVar[int] + removeImplicitKey: RemoveImplicitKey + setImplicitKey: SetImplicitKey + def __init__( + self, + setImplicitKey: Optional[Union[SetImplicitKey, Mapping]] = ..., + removeImplicitKey: Optional[Union[RemoveImplicitKey, Mapping]] = ..., + ) -> None: ... + +class RemoveImplicitKey(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class SetHandleState(_message.Message): + __slots__ = ["state"] + STATE_FIELD_NUMBER: ClassVar[int] + state: HandleState + def __init__(self, state: Optional[Union[HandleState, str]] = ...) -> None: ... + +class SetImplicitKey(_message.Message): + __slots__ = ["key"] + KEY_FIELD_NUMBER: ClassVar[int] + key: bytes + def __init__(self, key: Optional[bytes] = ...) -> None: ... + +class StateCallCommand(_message.Message): + __slots__ = ["schema", "stateName", "ttl"] + SCHEMA_FIELD_NUMBER: ClassVar[int] + STATENAME_FIELD_NUMBER: ClassVar[int] + TTL_FIELD_NUMBER: ClassVar[int] + schema: str + stateName: str + ttl: TTLConfig + def __init__( + self, + stateName: Optional[str] = ..., + schema: Optional[str] = ..., + ttl: Optional[Union[TTLConfig, Mapping]] = ..., + ) -> None: ... class StateRequest(_message.Message): - __slots__ = ( - "version", - "statefulProcessorCall", - "stateVariableRequest", + __slots__ = [ "implicitGroupingKeyRequest", - ) - VERSION_FIELD_NUMBER: _ClassVar[int] - STATEFULPROCESSORCALL_FIELD_NUMBER: _ClassVar[int] - STATEVARIABLEREQUEST_FIELD_NUMBER: _ClassVar[int] - IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: _ClassVar[int] - version: int - statefulProcessorCall: StatefulProcessorCall - stateVariableRequest: StateVariableRequest + "stateVariableRequest", + "statefulProcessorCall", + "version", + ] + IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: ClassVar[int] + STATEFULPROCESSORCALL_FIELD_NUMBER: ClassVar[int] + STATEVARIABLEREQUEST_FIELD_NUMBER: ClassVar[int] + VERSION_FIELD_NUMBER: ClassVar[int] implicitGroupingKeyRequest: ImplicitGroupingKeyRequest + stateVariableRequest: StateVariableRequest + statefulProcessorCall: StatefulProcessorCall + version: int def __init__( self, - version: _Optional[int] = ..., - statefulProcessorCall: _Optional[_Union[StatefulProcessorCall, _Mapping]] = ..., - stateVariableRequest: _Optional[_Union[StateVariableRequest, _Mapping]] = ..., - implicitGroupingKeyRequest: _Optional[_Union[ImplicitGroupingKeyRequest, _Mapping]] = ..., + version: Optional[int] = ..., + statefulProcessorCall: Optional[Union[StatefulProcessorCall, Mapping]] = ..., + stateVariableRequest: Optional[Union[StateVariableRequest, Mapping]] = ..., + implicitGroupingKeyRequest: Optional[Union[ImplicitGroupingKeyRequest, Mapping]] = ..., ) -> None: ... class StateResponse(_message.Message): - __slots__ = ("statusCode", "errorMessage", "value") - STATUSCODE_FIELD_NUMBER: _ClassVar[int] - ERRORMESSAGE_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - statusCode: int + __slots__ = ["errorMessage", "statusCode", "value"] + ERRORMESSAGE_FIELD_NUMBER: ClassVar[int] + STATUSCODE_FIELD_NUMBER: ClassVar[int] + VALUE_FIELD_NUMBER: ClassVar[int] errorMessage: str + statusCode: int value: bytes - def __init__( - self, statusCode: _Optional[int] = ..., errorMessage: _Optional[str] = ... - ) -> None: ... 
- -class StatefulProcessorCall(_message.Message): - __slots__ = ("setHandleState", "getValueState", "getListState", "getMapState") - SETHANDLESTATE_FIELD_NUMBER: _ClassVar[int] - GETVALUESTATE_FIELD_NUMBER: _ClassVar[int] - GETLISTSTATE_FIELD_NUMBER: _ClassVar[int] - GETMAPSTATE_FIELD_NUMBER: _ClassVar[int] - setHandleState: SetHandleState - getValueState: StateCallCommand - getListState: StateCallCommand - getMapState: StateCallCommand def __init__( self, - setHandleState: _Optional[_Union[SetHandleState, _Mapping]] = ..., - getValueState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., - getListState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., - getMapState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., + statusCode: Optional[int] = ..., + errorMessage: Optional[str] = ..., + value: Optional[bytes] = ..., ) -> None: ... class StateVariableRequest(_message.Message): - __slots__ = ("valueStateCall",) - VALUESTATECALL_FIELD_NUMBER: _ClassVar[int] + __slots__ = ["valueStateCall"] + VALUESTATECALL_FIELD_NUMBER: ClassVar[int] valueStateCall: ValueStateCall - def __init__( - self, valueStateCall: _Optional[_Union[ValueStateCall, _Mapping]] = ... - ) -> None: ... + def __init__(self, valueStateCall: Optional[Union[ValueStateCall, Mapping]] = ...) -> None: ... -class ImplicitGroupingKeyRequest(_message.Message): - __slots__ = ("setImplicitKey", "removeImplicitKey") - SETIMPLICITKEY_FIELD_NUMBER: _ClassVar[int] - REMOVEIMPLICITKEY_FIELD_NUMBER: _ClassVar[int] - setImplicitKey: SetImplicitKey - removeImplicitKey: RemoveImplicitKey +class StatefulProcessorCall(_message.Message): + __slots__ = ["getListState", "getMapState", "getValueState", "setHandleState"] + GETLISTSTATE_FIELD_NUMBER: ClassVar[int] + GETMAPSTATE_FIELD_NUMBER: ClassVar[int] + GETVALUESTATE_FIELD_NUMBER: ClassVar[int] + SETHANDLESTATE_FIELD_NUMBER: ClassVar[int] + getListState: StateCallCommand + getMapState: StateCallCommand + getValueState: StateCallCommand + setHandleState: SetHandleState def __init__( self, - setImplicitKey: _Optional[_Union[SetImplicitKey, _Mapping]] = ..., - removeImplicitKey: _Optional[_Union[RemoveImplicitKey, _Mapping]] = ..., + setHandleState: Optional[Union[SetHandleState, Mapping]] = ..., + getValueState: Optional[Union[StateCallCommand, Mapping]] = ..., + getListState: Optional[Union[StateCallCommand, Mapping]] = ..., + getMapState: Optional[Union[StateCallCommand, Mapping]] = ..., ) -> None: ... -class StateCallCommand(_message.Message): - __slots__ = ("stateName", "schema") - STATENAME_FIELD_NUMBER: _ClassVar[int] - SCHEMA_FIELD_NUMBER: _ClassVar[int] - stateName: str - schema: str - def __init__(self, stateName: _Optional[str] = ..., schema: _Optional[str] = ...) -> None: ... +class TTLConfig(_message.Message): + __slots__ = ["durationMs"] + DURATIONMS_FIELD_NUMBER: ClassVar[int] + durationMs: int + def __init__(self, durationMs: Optional[int] = ...) -> None: ... 
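The regenerated stubs above add a TTLConfig message and a ttl field on StateCallCommand. A small sketch of how such a request can be assembled, mirroring get_value_state in stateful_processor_api_client.py further down in this diff (the state name and timeout are illustrative):

from pyspark.sql.types import StructType, StructField, IntegerType
import pyspark.sql.streaming.StateMessage_pb2 as stateMessage

value_schema = StructType([StructField("value", IntegerType(), True)])
cmd = stateMessage.StateCallCommand()
cmd.stateName = "numViolations"
cmd.schema = value_schema.json()
cmd.ttl.durationMs = 30000  # new TTLConfig sub-message carrying the TTL in milliseconds
call = stateMessage.StatefulProcessorCall(getValueState=cmd)
request = stateMessage.StateRequest(statefulProcessorCall=call)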
class ValueStateCall(_message.Message): - __slots__ = ("stateName", "exists", "get", "valueStateUpdate", "clear") - STATENAME_FIELD_NUMBER: _ClassVar[int] - EXISTS_FIELD_NUMBER: _ClassVar[int] - GET_FIELD_NUMBER: _ClassVar[int] - VALUESTATEUPDATE_FIELD_NUMBER: _ClassVar[int] - CLEAR_FIELD_NUMBER: _ClassVar[int] - stateName: str + __slots__ = ["clear", "exists", "get", "stateName", "valueStateUpdate"] + CLEAR_FIELD_NUMBER: ClassVar[int] + EXISTS_FIELD_NUMBER: ClassVar[int] + GET_FIELD_NUMBER: ClassVar[int] + STATENAME_FIELD_NUMBER: ClassVar[int] + VALUESTATEUPDATE_FIELD_NUMBER: ClassVar[int] + clear: Clear exists: Exists get: Get + stateName: str valueStateUpdate: ValueStateUpdate - clear: Clear def __init__( self, - stateName: _Optional[str] = ..., - exists: _Optional[_Union[Exists, _Mapping]] = ..., - get: _Optional[_Union[Get, _Mapping]] = ..., - valueStateUpdate: _Optional[_Union[ValueStateUpdate, _Mapping]] = ..., - clear: _Optional[_Union[Clear, _Mapping]] = ..., + stateName: Optional[str] = ..., + exists: Optional[Union[Exists, Mapping]] = ..., + get: Optional[Union[Get, Mapping]] = ..., + valueStateUpdate: Optional[Union[ValueStateUpdate, Mapping]] = ..., + clear: Optional[Union[Clear, Mapping]] = ..., ) -> None: ... -class SetImplicitKey(_message.Message): - __slots__ = ("key",) - KEY_FIELD_NUMBER: _ClassVar[int] - key: bytes - def __init__(self, key: _Optional[bytes] = ...) -> None: ... - -class RemoveImplicitKey(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class Exists(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class Get(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - class ValueStateUpdate(_message.Message): - __slots__ = ("value",) - VALUE_FIELD_NUMBER: _ClassVar[int] + __slots__ = ["value"] + VALUE_FIELD_NUMBER: ClassVar[int] value: bytes - def __init__(self, value: _Optional[bytes] = ...) -> None: ... + def __init__(self, value: Optional[bytes] = ...) -> None: ... -class Clear(_message.Message): +class HandleState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): __slots__ = () - def __init__(self) -> None: ... - -class SetHandleState(_message.Message): - __slots__ = ("state",) - STATE_FIELD_NUMBER: _ClassVar[int] - state: HandleState - def __init__(self, state: _Optional[_Union[HandleState, str]] = ...) -> None: ... diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index a378eec2b6175..9045c81e287cd 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -88,7 +88,9 @@ class StatefulProcessorHandle: def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient) -> None: self.stateful_processor_api_client = stateful_processor_api_client - def getValueState(self, state_name: str, schema: Union[StructType, str]) -> ValueState: + def getValueState( + self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] = None + ) -> ValueState: """ Function to create new or return existing single value state variable of given type. The user must ensure to call this function only within the `init()` method of the @@ -101,8 +103,13 @@ def getValueState(self, state_name: str, schema: Union[StructType, str]) -> Valu schema : :class:`pyspark.sql.types.DataType` or str The schema of the state variable. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. 
+ ttlDurationMs: int + Time to live duration of the state in milliseconds. State values will not be returned + past ttlDuration and will be eventually removed from the state store. Any state update + resets the expiration time to current processing time plus ttlDuration. + If ttl is not specified the state will never expire. """ - self.stateful_processor_api_client.get_value_state(state_name, schema) + self.stateful_processor_api_client.get_value_state(state_name, schema, ttl_duration_ms) return ValueState(ValueStateClient(self.stateful_processor_api_client), state_name, schema) diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index 080d7739992ec..9703aa17d3474 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -17,7 +17,7 @@ from enum import Enum import os import socket -from typing import Any, Union, cast, Tuple +from typing import Any, Union, Optional, cast, Tuple from pyspark.serializers import write_int, read_int, UTF8Deserializer from pyspark.sql.types import StructType, _parse_datatype_string, Row @@ -101,7 +101,9 @@ def remove_implicit_key(self) -> None: # TODO(SPARK-49233): Classify errors thrown by internal methods. raise PySparkRuntimeError(f"Error removing implicit key: " f"{response_message[1]}") - def get_value_state(self, state_name: str, schema: Union[StructType, str]) -> None: + def get_value_state( + self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] + ) -> None: import pyspark.sql.streaming.StateMessage_pb2 as stateMessage if isinstance(schema, str): @@ -110,6 +112,8 @@ def get_value_state(self, state_name: str, schema: Union[StructType, str]) -> No state_call_command = stateMessage.StateCallCommand() state_call_command.stateName = state_name state_call_command.schema = schema.json() + if ttl_duration_ms is not None: + state_call_command.ttl.durationMs = ttl_duration_ms call = stateMessage.StatefulProcessorCall(getValueState=state_call_command) message = stateMessage.StateRequest(statefulProcessorCall=call) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 196c9eb5d81d8..5deb73a0ccf90 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -417,6 +417,18 @@ def checks(): for b in parameters: not_found_fails(b) + def test_observed_session_id(self): + stub = self._stub_with([self.response, self.finished]) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) + session_id = "test-session-id" + + reattach = ite._create_reattach_execute_request() + self.assertEqual(reattach.client_observed_server_side_session_id, "") + + self.request.client_observed_server_side_session_id = session_id + reattach = ite._create_reattach_execute_request() + self.assertEqual(reattach.client_observed_server_side_session_id, session_id) + if __name__ == "__main__": from pyspark.sql.tests.connect.client.test_client import * # noqa: F401 diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index c3d0d28017e60..f05f982d2d14e 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -176,7 +176,7 @@ def 
test_slow_query(self): def test_listener_throw(self): """ - Following Vanilla Spark's behavior, when the callback of user-defined listener throws, + Following classic Spark's behavior, when the callback of user-defined listener throws, other listeners should still proceed. """ diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index 51ce1cd685210..e29873173cc3a 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -2572,7 +2572,7 @@ def test_function_parity(self): cf_fn = {name for (name, value) in getmembers(CF, isfunction) if name[0] != "_"} - # Functions in vanilla PySpark we do not expect to be available in Spark Connect + # Functions in classic PySpark we do not expect to be available in Spark Connect sf_excluded_fn = set() self.assertEqual( @@ -2581,7 +2581,7 @@ def test_function_parity(self): "Missing functions in Spark Connect not as expected", ) - # Functions in Spark Connect we do not expect to be available in vanilla PySpark + # Functions in Spark Connect we do not expect to be available in classic PySpark cf_excluded_fn = { "check_dependencies", # internal helper function } @@ -2589,7 +2589,7 @@ def test_function_parity(self): self.assertEqual( cf_fn - sf_fn, cf_excluded_fn, - "Missing functions in vanilla PySpark not as expected", + "Missing functions in classic PySpark not as expected", ) # SPARK-45216: Fix non-deterministic seeded Dataset APIs diff --git a/python/pyspark/sql/tests/connect/test_parity_frame_plot.py b/python/pyspark/sql/tests/connect/test_parity_frame_plot.py new file mode 100644 index 0000000000000..c69e438bf7eb0 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_frame_plot.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
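From the user side, the getValueState change above boils down to an optional third argument. A minimal sketch of a processor opting into TTL (class name, state name, and timeout are illustrative; the test suite's own SimpleTTLStatefulProcessor and TTLStatefulProcessor appear later in this diff):

from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, IntegerType

class TtlCountProcessor(StatefulProcessor):  # hypothetical example processor
    def init(self, handle: StatefulProcessorHandle) -> None:
        schema = StructType([StructField("value", IntegerType(), True)])
        # Pass ttl_duration_ms to get a value state that expires; omit it
        # (or pass None) for state that never expires.
        self.count_state = handle.getValueState("count", schema, ttl_duration_ms=30000)
    # handleInputRows/close omitted for brevity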
+# + +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.sql.tests.plot.test_frame_plot import DataFramePlotTestsMixin + + +class FramePlotParityTests(DataFramePlotTestsMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_frame_plot import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py b/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py new file mode 100644 index 0000000000000..78508fe533379 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.sql.tests.plot.test_frame_plot_plotly import DataFramePlotPlotlyTestsMixin + + +class FramePlotPlotlyParityTests(DataFramePlotPlotlyTestsMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_frame_plot_plotly import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index f05a601094a5d..8ad24704de3a4 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -16,6 +16,7 @@ # import os +import time import tempfile from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle from typing import Iterator @@ -32,6 +33,7 @@ Row, IntegerType, ) +from pyspark.testing import assertDataFrameEqual from pyspark.testing.sqlutils import ( ReusedSQLTestCase, have_pandas, @@ -59,20 +61,18 @@ def conf(cls): ) return cfg + def _prepare_input_data(self, input_path, col1, col2): + with open(input_path, "w") as fw: + for e1, e2 in zip(col1, col2): + fw.write(f"{e1}, {e2}\n") + def _prepare_test_resource1(self, input_path): - with open(input_path + "/text-test1.txt", "w") as fw: - fw.write("0, 123\n") - fw.write("0, 46\n") - fw.write("1, 146\n") - fw.write("1, 346\n") + self._prepare_input_data(input_path + "/text-test1.txt", [0, 0, 1, 1], [123, 46, 146, 346]) def _prepare_test_resource2(self, input_path): - with open(input_path + 
"/text-test2.txt", "w") as fw: - fw.write("0, 123\n") - fw.write("0, 223\n") - fw.write("0, 323\n") - fw.write("1, 246\n") - fw.write("1, 6\n") + self._prepare_input_data( + input_path + "/text-test2.txt", [0, 0, 0, 1, 1], [123, 223, 323, 246, 6] + ) def _build_test_df(self, input_path): df = self.spark.readStream.format("text").option("maxFilesPerTrigger", 1).load(input_path) @@ -84,7 +84,7 @@ def _build_test_df(self, input_path): return df_final def _test_transform_with_state_in_pandas_basic( - self, stateful_processor, check_results, single_batch=False + self, stateful_processor, check_results, single_batch=False, timeMode="None" ): input_path = tempfile.mkdtemp() self._prepare_test_resource1(input_path) @@ -110,7 +110,7 @@ def _test_transform_with_state_in_pandas_basic( statefulProcessor=stateful_processor, outputStructType=output_schema, outputMode="Update", - timeMode="None", + timeMode=timeMode, ) .writeStream.queryName("this_query") .foreachBatch(check_results) @@ -211,6 +211,99 @@ def test_transform_with_state_in_pandas_query_restarts(self): Row(id="1", countAsString="2"), } + # test value state with ttl has the same behavior as value state when + # state doesn't expire. + def test_value_state_ttl_basic(self): + def check_results(batch_df, batch_id): + if batch_id == 0: + assert set(batch_df.sort("id").collect()) == { + Row(id="0", countAsString="2"), + Row(id="1", countAsString="2"), + } + else: + assert set(batch_df.sort("id").collect()) == { + Row(id="0", countAsString="3"), + Row(id="1", countAsString="2"), + } + + self._test_transform_with_state_in_pandas_basic( + SimpleTTLStatefulProcessor(), check_results, False, "processingTime" + ) + + def test_value_state_ttl_expiration(self): + def check_results(batch_df, batch_id): + if batch_id == 0: + assertDataFrameEqual( + batch_df, + [ + Row(id="ttl-count-0", count=1), + Row(id="count-0", count=1), + Row(id="ttl-count-1", count=1), + Row(id="count-1", count=1), + ], + ) + elif batch_id == 1: + assertDataFrameEqual( + batch_df, + [ + Row(id="ttl-count-0", count=2), + Row(id="count-0", count=2), + Row(id="ttl-count-1", count=2), + Row(id="count-1", count=2), + ], + ) + elif batch_id == 2: + # ttl-count-0 expire and restart from count 0. 
+ # ttl-count-1 get reset in batch 1 and keep the state + # non-ttl state never expires + assertDataFrameEqual( + batch_df, + [ + Row(id="ttl-count-0", count=1), + Row(id="count-0", count=3), + Row(id="ttl-count-1", count=3), + Row(id="count-1", count=3), + ], + ) + if batch_id == 0 or batch_id == 1: + time.sleep(6) + + input_dir = tempfile.TemporaryDirectory() + input_path = input_dir.name + try: + df = self._build_test_df(input_path) + self._prepare_input_data(input_path + "/batch1.txt", [1, 0], [0, 0]) + self._prepare_input_data(input_path + "/batch2.txt", [1, 0], [0, 0]) + self._prepare_input_data(input_path + "/batch3.txt", [1, 0], [0, 0]) + for q in self.spark.streams.active: + q.stop() + output_schema = StructType( + [ + StructField("id", StringType(), True), + StructField("count", IntegerType(), True), + ] + ) + + q = ( + df.groupBy("id") + .transformWithStateInPandas( + statefulProcessor=TTLStatefulProcessor(), + outputStructType=output_schema, + outputMode="Update", + timeMode="processingTime", + ) + .writeStream.foreachBatch(check_results) + .outputMode("update") + .start() + ) + self.assertTrue(q.isActive) + q.processAllAvailable() + q.stop() + q.awaitTermination() + self.assertTrue(q.exception() is None) + finally: + input_dir.cleanup() + class SimpleStatefulProcessor(StatefulProcessor): dict = {0: {"0": 1, "1": 2}, 1: {"0": 4, "1": 3}} @@ -246,6 +339,43 @@ def close(self) -> None: pass +# A stateful processor that inherit all behavior of SimpleStatefulProcessor except that it use +# ttl state with a large timeout. +class SimpleTTLStatefulProcessor(SimpleStatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + state_schema = StructType([StructField("value", IntegerType(), True)]) + self.num_violations_state = handle.getValueState("numViolations", state_schema, 30000) + + +class TTLStatefulProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + state_schema = StructType([StructField("value", IntegerType(), True)]) + self.ttl_count_state = handle.getValueState("ttl-state", state_schema, 10000) + self.count_state = handle.getValueState("state", state_schema) + + def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]: + count = 0 + ttl_count = 0 + id = key[0] + if self.count_state.exists(): + count = self.count_state.get()[0] + if self.ttl_count_state.exists(): + ttl_count = self.ttl_count_state.get()[0] + for pdf in rows: + pdf_count = pdf.count().get("temperature") + count += pdf_count + ttl_count += pdf_count + + self.count_state.update((count,)) + # skip updating state for the 2nd batch so that ttl state expire + if not (ttl_count == 2 and id == "0"): + self.ttl_count_state.update((ttl_count,)) + yield pd.DataFrame({"id": [f"ttl-count-{id}", f"count-{id}"], "count": [ttl_count, count]}) + + def close(self) -> None: + pass + + class InvalidSimpleStatefulProcessor(StatefulProcessor): def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("value", IntegerType(), True)]) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py index 6720dfc37d0cc..228fc30b497cc 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py @@ -339,6 +339,19 @@ def noop(s: pd.Series) -> pd.Series: self.assertEqual(df.schema[0].dataType.simpleString(), "interval day to second") self.assertEqual(df.first()[0], datetime.timedelta(microseconds=123)) + def 
test_pandas_udf_return_type_error(self): + import pandas as pd + + @pandas_udf("s string") + def upper(s: pd.Series) -> pd.Series: + return s.str.upper() + + df = self.spark.createDataFrame([("a",)], schema="s string") + + self.assertRaisesRegex( + PythonException, "Invalid return type", df.select(upper("s")).collect + ) + class PandasUDFTests(PandasUDFTestsMixin, ReusedSQLTestCase): pass diff --git a/python/pyspark/sql/tests/plot/__init__.py b/python/pyspark/sql/tests/plot/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/sql/tests/plot/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/sql/tests/plot/test_frame_plot.py b/python/pyspark/sql/tests/plot/test_frame_plot.py new file mode 100644 index 0000000000000..f753b5ab3db72 --- /dev/null +++ b/python/pyspark/sql/tests/plot/test_frame_plot.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest +from pyspark.errors import PySparkValueError +from pyspark.sql import Row +from pyspark.sql.plot import PySparkSampledPlotBase, PySparkTopNPlotBase +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message + + +@unittest.skipIf(not have_plotly, plotly_requirement_message) +class DataFramePlotTestsMixin: + def test_backend(self): + accessor = self.spark.range(2).plot + backend = accessor._get_plot_backend() + self.assertEqual(backend.__name__, "pyspark.sql.plot.plotly") + + with self.assertRaises(PySparkValueError) as pe: + accessor._get_plot_backend("matplotlib") + + self.check_error( + exception=pe.exception, + errorClass="UNSUPPORTED_PLOT_BACKEND", + messageParameters={"backend": "matplotlib", "supported_backends": "plotly"}, + ) + + def test_topn_max_rows(self): + try: + self.spark.conf.set("spark.sql.pyspark.plotting.max_rows", "1000") + sdf = self.spark.range(2500) + pdf = PySparkTopNPlotBase().get_top_n(sdf) + self.assertEqual(len(pdf), 1000) + finally: + self.spark.conf.unset("spark.sql.pyspark.plotting.max_rows") + + def test_sampled_plot_with_ratio(self): + try: + self.spark.conf.set("spark.sql.pyspark.plotting.sample_ratio", "0.5") + data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2500)] + sdf = self.spark.createDataFrame(data) + pdf = PySparkSampledPlotBase().get_sampled(sdf) + self.assertEqual(round(len(pdf) / 2500, 1), 0.5) + finally: + self.spark.conf.unset("spark.sql.pyspark.plotting.sample_ratio") + + def test_sampled_plot_with_max_rows(self): + data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2000)] + sdf = self.spark.createDataFrame(data) + pdf = PySparkSampledPlotBase().get_sampled(sdf) + self.assertEqual(round(len(pdf) / 2000, 1), 0.5) + + +class DataFramePlotTests(DataFramePlotTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.plot.test_frame_plot import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py new file mode 100644 index 0000000000000..72a3ed267d192 --- /dev/null +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
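The tests above exercise the two session configurations the new sampling helpers read. A quick sketch of tuning them before plotting a large DataFrame (the values and the spark session name are illustrative):

# Cap how many rows a top-N plot pulls to the driver.
spark.conf.set("spark.sql.pyspark.plotting.max_rows", "1000")
# Pin the sampling fraction for sampled plots such as line; when unset,
# the fraction is derived from max_rows and the DataFrame's row count.
spark.conf.set("spark.sql.pyspark.plotting.sample_ratio", "0.5")

fig = spark.range(10000).selectExpr("id", "id * 2 AS y").plot.line(x="id", y="y")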
+# + +import unittest +import pyspark.sql.plot # noqa: F401 +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message + + +@unittest.skipIf(not have_plotly, plotly_requirement_message) +class DataFramePlotPlotlyTestsMixin: + @property + def sdf(self): + data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + columns = ["category", "int_val", "float_val"] + return self.spark.createDataFrame(data, columns) + + def _check_fig_data(self, fig_data, expected_x, expected_y, expected_name=""): + self.assertEqual(fig_data["mode"], "lines") + self.assertEqual(fig_data["type"], "scatter") + self.assertEqual(fig_data["xaxis"], "x") + self.assertEqual(list(fig_data["x"]), expected_x) + self.assertEqual(fig_data["yaxis"], "y") + self.assertEqual(list(fig_data["y"]), expected_y) + self.assertEqual(fig_data["name"], expected_name) + + def test_line_plot(self): + # single column as vertical axis + fig = self.sdf.plot(kind="line", x="category", y="int_val") + self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20]) + + # multiple columns as vertical axis + fig = self.sdf.plot.line(x="category", y=["int_val", "float_val"]) + self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val") + self._check_fig_data(fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val") + + +class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.plot.test_frame_plot_plotly import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py index 5041fefff1909..b29338e7f59e7 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py @@ -220,9 +220,10 @@ def close(self, error): try: tester.run_streaming_query_on_writer(ForeachWriter(), 1) self.fail("bad writer did not fail the query") # this is not expected - except StreamingQueryException: - # TODO: Verify whether original error message is inside the exception - pass + except StreamingQueryException as e: + err_msg = str(e) + self.assertTrue("test error" in err_msg) + self.assertTrue("FOREACH_USER_FUNCTION_ERROR" in err_msg) self.assertEqual(len(tester.process_events()), 0) # no row was processed close_events = tester.close_events() diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 2bd66baaa2bfe..1972dd2804d98 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -18,11 +18,14 @@ from enum import Enum from itertools import chain +import datetime +import unittest + from pyspark.sql import Column, Row from pyspark.sql import functions as sf from pyspark.sql.types import StructType, StructField, IntegerType, LongType from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError -from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, pandas_requirement_message class ColumnTestsMixin: @@ -280,6 +283,33 @@ def test_expr_str_representation(self): when_cond = sf.when(expression, sf.lit(None)) self.assertEqual(str(when_cond), "Column<'CASE WHEN foo THEN NULL END'>") + 
def test_lit_time_representation(self): + dt = datetime.date(2021, 3, 4) + self.assertEqual(str(sf.lit(dt)), "Column<'2021-03-04'>") + + ts = datetime.datetime(2021, 3, 4, 12, 34, 56, 1234) + self.assertEqual(str(sf.lit(ts)), "Column<'2021-03-04 12:34:56.001234'>") + + @unittest.skipIf(not have_pandas, pandas_requirement_message) + def test_lit_delta_representation(self): + for delta in [ + datetime.timedelta(days=1), + datetime.timedelta(hours=2), + datetime.timedelta(minutes=3), + datetime.timedelta(seconds=4), + datetime.timedelta(microseconds=5), + datetime.timedelta(days=2, hours=21, microseconds=908), + datetime.timedelta(days=1, minutes=-3, microseconds=-1001), + datetime.timedelta(days=1, hours=2, minutes=3, seconds=4, microseconds=5), + ]: + import pandas as pd + + # Column<'PT69H0.000908S'> or Column<'P2DT21H0M0.000908S'> + s = str(sf.lit(delta)) + + # Parse the ISO string representation and compare + self.assertTrue(pd.Timedelta(s[8:-2]).to_pytimedelta() == delta) + def test_enum_literals(self): class IntEnum(Enum): X = 1 diff --git a/python/pyspark/sql/tests/test_creation.py b/python/pyspark/sql/tests/test_creation.py index dfe66cdd3edf0..c6917aa234b41 100644 --- a/python/pyspark/sql/tests/test_creation.py +++ b/python/pyspark/sql/tests/test_creation.py @@ -15,7 +15,6 @@ # limitations under the License. # -import platform from decimal import Decimal import os import time @@ -111,11 +110,7 @@ def test_create_dataframe_from_pandas_with_dst(self): os.environ["TZ"] = orig_env_tz time.tzset() - # TODO(SPARK-43354): Re-enable test_create_dataframe_from_pandas_with_day_time_interval - @unittest.skipIf( - "pypy" in platform.python_implementation().lower() or not have_pandas, - "Fails in PyPy Python 3.8, should enable.", - ) + @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore def test_create_dataframe_from_pandas_with_day_time_interval(self): # SPARK-37277: Test DayTimeIntervalType in createDataFrame without Arrow. 
import pandas as pd diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index a214b874f5ec0..8ec0839ec1fe4 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -39,6 +39,7 @@ PySparkTypeError, PySparkValueError, ) +from pyspark.testing import assertDataFrameEqual from pyspark.testing.sqlutils import ( ReusedSQLTestCase, have_pyarrow, @@ -955,6 +956,74 @@ def test_checkpoint_dataframe(self): self.spark.range(1).localCheckpoint().explain() self.assertIn("ExistingRDD", buf.getvalue()) + def test_transpose(self): + df = self.spark.createDataFrame([{"a": "x", "b": "y", "c": "z"}]) + + # default index column + transposed_df = df.transpose() + expected_schema = StructType( + [StructField("key", StringType(), False), StructField("x", StringType(), True)] + ) + expected_data = [Row(key="b", x="y"), Row(key="c", x="z")] + expected_df = self.spark.createDataFrame(expected_data, schema=expected_schema) + assertDataFrameEqual(transposed_df, expected_df, checkRowOrder=True) + + # specified index column + transposed_df = df.transpose("c") + expected_schema = StructType( + [StructField("key", StringType(), False), StructField("z", StringType(), True)] + ) + expected_data = [Row(key="a", z="x"), Row(key="b", z="y")] + expected_df = self.spark.createDataFrame(expected_data, schema=expected_schema) + assertDataFrameEqual(transposed_df, expected_df, checkRowOrder=True) + + # enforce transpose max values + with self.sql_conf({"spark.sql.transposeMaxValues": 0}): + with self.assertRaises(AnalysisException) as pe: + df.transpose().collect() + self.check_error( + exception=pe.exception, + errorClass="TRANSPOSE_EXCEED_ROW_LIMIT", + messageParameters={"maxValues": "0", "config": "spark.sql.transposeMaxValues"}, + ) + + # enforce ascending order based on index column values for transposed columns + df = self.spark.createDataFrame([{"a": "z"}, {"a": "y"}, {"a": "x"}]) + transposed_df = df.transpose() + expected_schema = StructType( + [ + StructField("key", StringType(), False), + StructField("x", StringType(), True), + StructField("y", StringType(), True), + StructField("z", StringType(), True), + ] + ) # z, y, x -> x, y, z + expected_df = self.spark.createDataFrame([], schema=expected_schema) + assertDataFrameEqual(transposed_df, expected_df, checkRowOrder=True) + + # enforce AtomicType Attribute for index column values + df = self.spark.createDataFrame([{"a": ["x", "x"], "b": "y", "c": "z"}]) + with self.assertRaises(AnalysisException) as pe: + df.transpose().collect() + self.check_error( + exception=pe.exception, + errorClass="TRANSPOSE_INVALID_INDEX_COLUMN", + messageParameters={ + "reason": "Index column must be of atomic type, " + "but found: ArrayType(StringType,true)" + }, + ) + + # enforce least common type for non-index columns + df = self.spark.createDataFrame([{"a": "x", "b": "y", "c": 1}]) + with self.assertRaises(AnalysisException) as pe: + df.transpose().collect() + self.check_error( + exception=pe.exception, + errorClass="TRANSPOSE_NO_LEAST_COMMON_TYPE", + messageParameters={"dt1": "STRING", "dt2": "BIGINT"}, + ) + class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase): def test_query_execution_unsupported_in_classic(self): diff --git a/python/pyspark/sql/tests/test_dataframe_query_context.py b/python/pyspark/sql/tests/test_dataframe_query_context.py index 3f31f1d62d73d..bf0cc021ca771 100644 --- a/python/pyspark/sql/tests/test_dataframe_query_context.py +++ 
b/python/pyspark/sql/tests/test_dataframe_query_context.py @@ -54,7 +54,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__add__", @@ -70,7 +69,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__sub__", @@ -86,7 +84,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__mul__", @@ -102,7 +99,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__mod__", @@ -118,7 +114,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__eq__", @@ -134,7 +129,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__lt__", @@ -150,7 +144,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__le__", @@ -166,7 +159,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__ge__", @@ -182,7 +174,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__gt__", @@ -198,7 +189,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="eqNullSafe", @@ -214,7 +204,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="bitwiseOR", @@ -230,7 +219,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="bitwiseAND", @@ -246,7 +234,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="bitwiseXOR", @@ -279,7 +266,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, 
query_context_type=QueryContextType.DataFrame, fragment="__add__", @@ -299,7 +285,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__sub__", @@ -317,7 +302,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__mul__", @@ -344,7 +328,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__add__", @@ -360,7 +343,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__sub__", @@ -376,7 +358,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__mul__", @@ -407,7 +388,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__add__", @@ -425,7 +405,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__sub__", @@ -443,7 +422,6 @@ def test_dataframe_query_context(self): "expression": "'string'", "sourceType": '"STRING"', "targetType": '"BIGINT"', - "ansiConfig": '"spark.sql.ansi.enabled"', }, query_context_type=QueryContextType.DataFrame, fragment="__mul__", diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index f7f2485a43e16..a0ab9bc9c7d40 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -1326,8 +1326,8 @@ def check(resultDf, expected): self.assertEqual([r[0] for r in resultDf.collect()], expected) check(df.select(F.is_variant_null(v)), [False, False]) - check(df.select(F.schema_of_variant(v)), ["STRUCT", "STRUCT"]) - check(df.select(F.schema_of_variant_agg(v)), ["STRUCT"]) + check(df.select(F.schema_of_variant(v)), ["OBJECT", "OBJECT"]) + check(df.select(F.schema_of_variant_agg(v)), ["OBJECT"]) check(df.select(F.variant_get(v, "$.a", "int")), [1, None]) check(df.select(F.variant_get(v, "$.b", "int")), [None, 2]) @@ -1365,6 +1365,13 @@ def test_try_parse_json(self): self.assertEqual("""{"a":1}""", actual[0]["var"]) self.assertEqual(None, actual[1]["var"]) + def test_to_variant_object(self): + df = self.spark.createDataFrame([(1, {"a": 1})], "i int, v struct") + actual = df.select( + F.to_json(F.to_variant_object(df.v)).alias("var"), + ).collect() + self.assertEqual("""{"a":1}""", actual[0]["var"]) + def test_schema_of_csv(self): with self.assertRaises(PySparkTypeError) as pe: F.schema_of_csv(1) diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py index 8431e9b3e35d4..140c7680b181b 100644 --- 
a/python/pyspark/sql/tests/test_python_datasource.py +++ b/python/pyspark/sql/tests/test_python_datasource.py @@ -374,6 +374,68 @@ def test_case_insensitive_dict(self): self.assertEqual(d2["BaR"], 3) self.assertEqual(d2["baz"], 3) + def test_arrow_batch_data_source(self): + import pyarrow as pa + + class ArrowBatchDataSource(DataSource): + """ + A data source for testing Arrow Batch Serialization + """ + + @classmethod + def name(cls): + return "arrowbatch" + + def schema(self): + return "key int, value string" + + def reader(self, schema: str): + return ArrowBatchDataSourceReader(schema, self.options) + + class ArrowBatchDataSourceReader(DataSourceReader): + def __init__(self, schema, options): + self.schema: str = schema + self.options = options + + def read(self, partition): + # Create Arrow Record Batch + keys = pa.array([1, 2, 3, 4, 5], type=pa.int32()) + values = pa.array(["one", "two", "three", "four", "five"], type=pa.string()) + schema = pa.schema([("key", pa.int32()), ("value", pa.string())]) + record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema) + yield record_batch + + def partitions(self): + # hardcoded number of partitions + num_part = 1 + return [InputPartition(i) for i in range(num_part)] + + self.spark.dataSource.register(ArrowBatchDataSource) + df = self.spark.read.format("arrowbatch").load() + expected_data = [ + Row(key=1, value="one"), + Row(key=2, value="two"), + Row(key=3, value="three"), + Row(key=4, value="four"), + Row(key=5, value="five"), + ] + assertDataFrameEqual(df, expected_data) + + with self.assertRaisesRegex( + PythonException, + "PySparkRuntimeError: \\[DATA_SOURCE_RETURN_SCHEMA_MISMATCH\\] Return schema" + " mismatch in the result from 'read' method\\. Expected: 1 columns, Found: 2 columns", + ): + self.spark.read.format("arrowbatch").schema("dummy int").load().show() + + with self.assertRaisesRegex( + PythonException, + "PySparkRuntimeError: \\[DATA_SOURCE_RETURN_SCHEMA_MISMATCH\\] Return schema mismatch" + " in the result from 'read' method\\. 
Expected: \\['key', 'dummy'\\] columns, Found:" + " \\['key', 'value'\\] columns", + ): + self.spark.read.format("arrowbatch").schema("key int, dummy string").load().show() + def test_data_source_type_mismatch(self): class TestDataSource(DataSource): @classmethod diff --git a/python/pyspark/sql/tests/test_python_streaming_datasource.py b/python/pyspark/sql/tests/test_python_streaming_datasource.py index 183b0ad80d9d4..fa14b37b57e62 100644 --- a/python/pyspark/sql/tests/test_python_streaming_datasource.py +++ b/python/pyspark/sql/tests/test_python_streaming_datasource.py @@ -152,6 +152,66 @@ def check_batch(df, batch_id): q.awaitTermination() self.assertIsNone(q.exception(), "No exception has to be propagated.") + def test_stream_reader_pyarrow(self): + import pyarrow as pa + + class TestStreamReader(DataSourceStreamReader): + def initialOffset(self): + return {"offset": 0} + + def latestOffset(self): + return {"offset": 2} + + def partitions(self, start, end): + # hardcoded number of partitions + num_part = 1 + return [InputPartition(i) for i in range(num_part)] + + def read(self, partition): + keys = pa.array([1, 2, 3, 4, 5], type=pa.int32()) + values = pa.array(["one", "two", "three", "four", "five"], type=pa.string()) + schema = pa.schema([("key", pa.int32()), ("value", pa.string())]) + record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema) + yield record_batch + + class TestDataSourcePyarrow(DataSource): + @classmethod + def name(cls): + return "testdatasourcepyarrow" + + def schema(self): + return "key int, value string" + + def streamReader(self, schema): + return TestStreamReader() + + self.spark.dataSource.register(TestDataSourcePyarrow) + df = self.spark.readStream.format("testdatasourcepyarrow").load() + + output_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_output") + checkpoint_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_checkpoint") + + q = ( + df.writeStream.format("json") + .option("checkpointLocation", checkpoint_dir.name) + .start(output_dir.name) + ) + while not q.recentProgress: + time.sleep(0.2) + q.stop() + q.awaitTermination() + + expected_data = [ + Row(key=1, value="one"), + Row(key=2, value="two"), + Row(key=3, value="three"), + Row(key=4, value="four"), + Row(key=5, value="five"), + ] + df = self.spark.read.json(output_dir.name) + + assertDataFrameEqual(df, expected_data) + def test_simple_stream_reader(self): class SimpleStreamReader(SimpleDataSourceStreamReader): def initialOffset(self): diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index a97734ee0fcef..5d9ec92cbc830 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -41,6 +41,7 @@ PythonException, UnknownException, SparkUpgradeException, + PySparkImportError, PySparkNotImplementedError, PySparkRuntimeError, ) @@ -115,6 +116,22 @@ def require_test_compiled() -> None: ) +def require_minimum_plotly_version() -> None: + """Raise ImportError if plotly is not installed""" + minimum_plotly_version = "4.8" + + try: + import plotly # noqa: F401 + except ImportError as error: + raise PySparkImportError( + errorClass="PACKAGE_NOT_INSTALLED", + messageParameters={ + "package_name": "plotly", + "minimum_version": str(minimum_plotly_version), + }, + ) from error + + class ForeachBatchFunction: """ This is the Python implementation of Java interface 'ForeachBatchFunction'. 
This wraps @@ -336,7 +353,7 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: def dispatch_df_method(f: FuncT) -> FuncT: """ - For the usecases of direct DataFrame.union(df, ...), it checks if self + For the use cases of direct DataFrame.method(df, ...), it checks if self is a Connect DataFrame or Classic DataFrame, and dispatches. """ @@ -363,8 +380,8 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: def dispatch_col_method(f: FuncT) -> FuncT: """ - For the usecases of direct Column.method(col, ...), it checks if self - is a Connect DataFrame or Classic DataFrame, and dispatches. + For the use cases of direct Column.method(col, ...), it checks if self + is a Connect Column or Classic Column, and dispatches. """ @functools.wraps(f) @@ -390,8 +407,9 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: def dispatch_window_method(f: FuncT) -> FuncT: """ - For the usecases of direct Window.method(col, ...), it checks if self - is a Connect Window or Classic Window, and dispatches. + For use cases of direct Window.method(col, ...), this function dispatches + the call to either ConnectWindow or ClassicWindow based on the execution + environment. """ @functools.wraps(f) @@ -405,11 +423,6 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: return getattr(ClassicWindow, f.__name__)(*args, **kwargs) - raise PySparkNotImplementedError( - errorClass="NOT_IMPLEMENTED", - messageParameters={"feature": f"Window.{f.__name__}"}, - ) - return cast(FuncT, wrapped) diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py index 16b98ac0ed1e4..2af25fb52f150 100644 --- a/python/pyspark/sql/worker/plan_data_source_read.py +++ b/python/pyspark/sql/worker/plan_data_source_read.py @@ -19,7 +19,7 @@ import sys import functools import pyarrow as pa -from itertools import islice +from itertools import islice, chain from typing import IO, List, Iterator, Iterable, Tuple, Union from pyspark.accumulators import _accumulatorRegistry @@ -59,21 +59,18 @@ def records_to_arrow_batches( - output_iter: Iterator[Tuple], + output_iter: Union[Iterator[Tuple], Iterator[pa.RecordBatch]], max_arrow_batch_size: int, return_type: StructType, data_source: DataSource, ) -> Iterable[pa.RecordBatch]: """ - Convert an iterator of Python tuples to an iterator of pyarrow record batches. - - For each python tuple, check the types of each field and append it to the records batch. - + First check if the iterator yields PyArrow's `pyarrow.RecordBatch`, if so, yield + them directly. Otherwise, convert an iterator of Python tuples to an iterator + of pyarrow record batches. For each Python tuple, check the types of each field + and append it to the records batch. 
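For data source authors, the practical effect of this change is that a reader's `read()` may now yield `pyarrow.RecordBatch` objects directly instead of Python tuples; tuple- and Row-yielding readers keep working unchanged, since the worker only switches to the pass-through path when the first yielded element is a `RecordBatch`. Below is a minimal sketch closely modeled on the `ArrowBatchDataSource` test added earlier in this diff; the class, format, and column names are illustrative only.

```python
import pyarrow as pa

from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition


class ExampleArrowSource(DataSource):
    """Sketch of a Python data source whose reader emits Arrow record batches."""

    @classmethod
    def name(cls):
        return "example_arrow"

    def schema(self):
        return "key int, value string"

    def reader(self, schema):
        return ExampleArrowReader()


class ExampleArrowReader(DataSourceReader):
    def partitions(self):
        # Single hardcoded partition, as in the test above.
        return [InputPartition(0)]

    def read(self, partition):
        # Yield a pyarrow.RecordBatch directly; the worker validates its column
        # count and names against the declared schema and passes it through
        # instead of converting row tuples.
        yield pa.RecordBatch.from_arrays(
            [pa.array([1, 2], type=pa.int32()), pa.array(["a", "b"], type=pa.string())],
            schema=pa.schema([("key", pa.int32()), ("value", pa.string())]),
        )


# Hypothetical usage, assuming an active SparkSession named `spark`:
# spark.dataSource.register(ExampleArrowSource)
# spark.read.format("example_arrow").load().show()
```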
""" - def batched(iterator: Iterator, n: int) -> Iterator: - return iter(functools.partial(lambda it: list(islice(it, n)), iterator), []) - pa_schema = to_arrow_schema(return_type) column_names = return_type.fieldNames() column_converters = [ @@ -83,6 +80,45 @@ def batched(iterator: Iterator, n: int) -> Iterator: num_cols = len(column_names) col_mapping = {name: i for i, name in enumerate(column_names)} col_name_set = set(column_names) + + try: + first_element = next(output_iter) + except StopIteration: + return + + # If the first element is of type pa.RecordBatch yield all elements and return + if isinstance(first_element, pa.RecordBatch): + # Validate the schema, check the RecordBatch column count + num_columns = first_element.num_columns + if num_columns != num_cols: + raise PySparkRuntimeError( + errorClass="DATA_SOURCE_RETURN_SCHEMA_MISMATCH", + messageParameters={ + "expected": str(num_cols), + "actual": str(num_columns), + }, + ) + for name in column_names: + if name not in first_element.schema.names: + raise PySparkRuntimeError( + errorClass="DATA_SOURCE_RETURN_SCHEMA_MISMATCH", + messageParameters={ + "expected": str(column_names), + "actual": str(first_element.schema.names), + }, + ) + + yield first_element + for element in output_iter: + yield element + return + + # Put the first element back to the iterator + output_iter = chain([first_element], output_iter) + + def batched(iterator: Iterator, n: int) -> Iterator: + return iter(functools.partial(lambda it: list(islice(it, n)), iterator), []) + for batch in batched(output_iter, max_arrow_batch_size): pylist: List[List] = [[] for _ in range(num_cols)] for result in batch: @@ -103,7 +139,8 @@ def batched(iterator: Iterator, n: int) -> Iterator: messageParameters={ "type": type(result).__name__, "name": data_source.name(), - "supported_types": "tuple, list, `pyspark.sql.types.Row`", + "supported_types": "tuple, list, `pyspark.sql.types.Row`," + " `pyarrow.RecordBatch`", }, ) @@ -145,9 +182,10 @@ def main(infile: IO, outfile: IO) -> None: This process then creates a `DataSourceReader` instance by calling the `reader` method on the `DataSource` instance. Then it calls the `partitions()` method of the reader and - constructs a Python UDTF using the `read()` method of the reader. + constructs a PyArrow's `RecordBatch` with the data using the `read()` method of the reader. - The partition values and the UDTF are then serialized and sent back to the JVM via the socket. + The partition values and the Arrow Batch are then serialized and sent back to the JVM + via the socket. 
""" try: check_python_version(infile) diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py index 9f07c44c084cf..00ad40e68bd7c 100644 --- a/python/pyspark/testing/sqlutils.py +++ b/python/pyspark/testing/sqlutils.py @@ -48,6 +48,13 @@ except Exception as e: test_not_compiled_message = str(e) +plotly_requirement_message = None +try: + import plotly +except ImportError as e: + plotly_requirement_message = str(e) +have_plotly = plotly_requirement_message is None + from pyspark.sql import SparkSession from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index f33f2111c5a1e..5488d11d868f5 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -185,7 +185,7 @@ def setUpClass(cls): def tearDownClass(cls): cls.sc.stop() - def test_assert_vanilla_mode(self): + def test_assert_classic_mode(self): from pyspark.sql import is_remote self.assertFalse(is_remote()) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 205e3d957a415..cca44435efe67 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -262,10 +262,6 @@ def try_simplify_traceback(tb: TracebackType) -> Optional[TracebackType]: if "pypy" in platform.python_implementation().lower(): # Traceback modification is not supported with PyPy in PySpark. return None - if sys.version_info[:2] < (3, 7): - # Traceback creation is not supported Python < 3.7. - # See https://bugs.python.org/issue30579. - return None import pyspark @@ -791,7 +787,7 @@ def is_remote_only() -> bool: if __name__ == "__main__": - if "pypy" not in platform.python_implementation().lower() and sys.version_info[:2] >= (3, 7): + if "pypy" not in platform.python_implementation().lower() and sys.version_info[:2] >= (3, 9): import doctest import pyspark.util from pyspark.core.context import SparkContext diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index fa0fd454ccc44..211c6c93b9674 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -29,15 +29,11 @@ Spark Project Kubernetes kubernetes - **/*Volcano*.scala volcano - - - io.fabric8 @@ -50,6 +46,40 @@ ${kubernetes-client.version} + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-volcano-source + generate-sources + + add-source + + + + volcano/src/main/scala + + + + + add-volcano-test-sources + generate-test-sources + + add-test-source + + + + volcano/src/test/scala + + + + + + + @@ -151,19 +181,6 @@ - - - - net.alchim31.maven - scala-maven-plugin - - - ${volcano.exclude} - - - - - target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 393ffc5674011..3a4d68c19014d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -776,7 +776,7 @@ private[spark] object Config extends Logging { val KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY = "options.medium" val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit" val KUBERNETES_VOLUMES_OPTIONS_SERVER_KEY = "options.server" - + val 
KUBERNETES_VOLUMES_LABEL_KEY = "label." val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." val KUBERNETES_DNS_SUBDOMAIN_NAME_MAX_LENGTH = 253 diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala index 3f7355de18911..9dfd40a773eb1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala @@ -24,7 +24,8 @@ private[spark] case class KubernetesHostPathVolumeConf(hostPath: String) private[spark] case class KubernetesPVCVolumeConf( claimName: String, storageClass: Option[String] = None, - size: Option[String] = None) + size: Option[String] = None, + labels: Option[Map[String, String]] = None) extends KubernetesVolumeSpecificConf private[spark] case class KubernetesEmptyDirVolumeConf( diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala index ee2108e8234d3..6463512c0114b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala @@ -45,13 +45,21 @@ object KubernetesVolumeUtils { val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY" val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY" val subPathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY" + val labelKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_LABEL_KEY" + + val volumeLabelsMap = properties + .filter(_._1.startsWith(labelKey)) + .map { + case (k, v) => k.replaceAll(labelKey, "") -> v + } KubernetesVolumeSpec( volumeName = volumeName, mountPath = properties(pathKey), mountSubPath = properties.getOrElse(subPathKey, ""), mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean), - volumeConf = parseVolumeSpecificConf(properties, volumeType, volumeName)) + volumeConf = parseVolumeSpecificConf(properties, + volumeType, volumeName, Option(volumeLabelsMap))) }.toSeq } @@ -74,7 +82,8 @@ object KubernetesVolumeUtils { private def parseVolumeSpecificConf( options: Map[String, String], volumeType: String, - volumeName: String): KubernetesVolumeSpecificConf = { + volumeName: String, + labels: Option[Map[String, String]]): KubernetesVolumeSpecificConf = { volumeType match { case KUBERNETES_VOLUMES_HOSTPATH_TYPE => val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_PATH_KEY" @@ -91,7 +100,8 @@ object KubernetesVolumeUtils { KubernetesPVCVolumeConf( options(claimNameKey), options.get(storageClassKey), - options.get(sizeLimitKey)) + options.get(sizeLimitKey), + labels) case KUBERNETES_VOLUMES_EMPTYDIR_TYPE => val mediumKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala index 72cc012a6bdd0..5cc61c746b0e0 100644 --- 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -74,7 +74,7 @@ private[spark] class MountVolumesFeatureStep(conf: KubernetesConf) new VolumeBuilder() .withHostPath(new HostPathVolumeSource(hostPath, "")) - case KubernetesPVCVolumeConf(claimNameTemplate, storageClass, size) => + case KubernetesPVCVolumeConf(claimNameTemplate, storageClass, size, labels) => val claimName = conf match { case c: KubernetesExecutorConf => claimNameTemplate @@ -86,12 +86,17 @@ private[spark] class MountVolumesFeatureStep(conf: KubernetesConf) .replaceAll(PVC_ON_DEMAND, s"${conf.resourceNamePrefix}-driver$PVC_POSTFIX-$i") } if (storageClass.isDefined && size.isDefined) { + val defaultVolumeLabels = Map(SPARK_APP_ID_LABEL -> conf.appId) + val volumeLabels = labels match { + case Some(customLabelsMap) => (customLabelsMap ++ defaultVolumeLabels).asJava + case None => defaultVolumeLabels.asJava + } additionalResources.append(new PersistentVolumeClaimBuilder() .withKind(PVC) .withApiVersion("v1") .withNewMetadata() .withName(claimName) - .addToLabels(SPARK_APP_ID_LABEL, conf.appId) + .addToLabels(volumeLabels) .endMetadata() .withNewSpec() .withStorageClassName(storageClass.get) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala index b70b9348d23b4..7e0a65bcdda90 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala @@ -117,12 +117,17 @@ object KubernetesTestConf { (KUBERNETES_VOLUMES_HOSTPATH_TYPE, Map(KUBERNETES_VOLUMES_OPTIONS_PATH_KEY -> path)) - case KubernetesPVCVolumeConf(claimName, storageClass, sizeLimit) => + case KubernetesPVCVolumeConf(claimName, storageClass, sizeLimit, labels) => val sconf = storageClass .map { s => (KUBERNETES_VOLUMES_OPTIONS_CLAIM_STORAGE_CLASS_KEY, s) }.toMap val lconf = sizeLimit.map { l => (KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY, l) }.toMap + val llabels = labels match { + case Some(value) => value.map { case(k, v) => s"label.$k" -> v } + case None => Map() + } (KUBERNETES_VOLUMES_PVC_TYPE, - Map(KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY -> claimName) ++ sconf ++ lconf) + Map(KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY -> claimName) ++ + sconf ++ lconf ++ llabels) case KubernetesEmptyDirVolumeConf(medium, sizeLimit) => val mconf = medium.map { m => (KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY, m) }.toMap diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala index fdc1aae0d4109..5c103739d3082 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala @@ -56,7 +56,39 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { assert(volumeSpec.mountPath === "/path") assert(volumeSpec.mountReadOnly) assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === - KubernetesPVCVolumeConf("claimName")) + 
KubernetesPVCVolumeConf("claimName", labels = Some(Map()))) + } + + test("SPARK-49598: Parses persistentVolumeClaim volumes correctly with labels") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true") + sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimName") + sparkConf.set("test.persistentVolumeClaim.volumeName.label.env", "test") + sparkConf.set("test.persistentVolumeClaim.volumeName.label.foo", "bar") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === + KubernetesPVCVolumeConf(claimName = "claimName", + labels = Some(Map("env" -> "test", "foo" -> "bar")))) + } + + test("SPARK-49598: Parses persistentVolumeClaim volumes & puts " + + "labels as empty Map if not provided") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true") + sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimName") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === + KubernetesPVCVolumeConf(claimName = "claimName", labels = Some(Map()))) } test("Parses emptyDir volumes correctly") { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala index 54796def95e53..6a68898c5f61c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -131,6 +131,79 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { assert(pvcClaim.getClaimName.endsWith("-driver-pvc-0")) } + test("SPARK-49598: Create and mounts persistentVolumeClaims in driver with labels") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "", + true, + KubernetesPVCVolumeConf(claimName = MountVolumesFeatureStep.PVC_ON_DEMAND, + storageClass = Some("gp3"), + size = Some("1Mi"), + labels = Some(Map("foo" -> "bar", "env" -> "test"))) + ) + + val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(pvcClaim.getClaimName.endsWith("-driver-pvc-0")) + } + + test("SPARK-49598: Create and mounts persistentVolumeClaims in executors with labels") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "", + true, + KubernetesPVCVolumeConf(claimName = MountVolumesFeatureStep.PVC_ON_DEMAND, + storageClass = Some("gp3"), + size = 
Some("1Mi"), + labels = Some(Map("foo1" -> "bar1", "env" -> "exec-test"))) + ) + + val executorConf = KubernetesTestConf.createExecutorConf(volumes = Seq(volumeConf)) + val executorStep = new MountVolumesFeatureStep(executorConf) + val executorPod = executorStep.configurePod(SparkPod.initialPod()) + + assert(executorPod.pod.getSpec.getVolumes.size() === 1) + val executorPVC = executorPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(executorPVC.getClaimName.endsWith("-exec-1-pvc-0")) + } + + test("SPARK-49598: Mount multiple volumes to executor with labels") { + val pvcVolumeConf1 = KubernetesVolumeSpec( + "checkpointVolume1", + "/checkpoints1", + "", + true, + KubernetesPVCVolumeConf(claimName = "pvcClaim1", + storageClass = Some("gp3"), + size = Some("1Mi"), + labels = Some(Map("foo1" -> "bar1", "env1" -> "exec-test-1"))) + ) + + val pvcVolumeConf2 = KubernetesVolumeSpec( + "checkpointVolume2", + "/checkpoints2", + "", + true, + KubernetesPVCVolumeConf(claimName = "pvcClaim2", + storageClass = Some("gp3"), + size = Some("1Mi"), + labels = Some(Map("foo2" -> "bar2", "env2" -> "exec-test-2"))) + ) + + val kubernetesConf = KubernetesTestConf.createExecutorConf( + volumes = Seq(pvcVolumeConf1, pvcVolumeConf2)) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 2) + assert(configuredPod.container.getVolumeMounts.size() === 2) + } + test("Create and mount persistentVolumeClaims in executors") { val volumeConf = KubernetesVolumeSpec( "testVolume", diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStep.scala b/resource-managers/kubernetes/core/volcano/src/main/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStep.scala similarity index 100% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStep.scala rename to resource-managers/kubernetes/core/volcano/src/main/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStep.scala diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStepSuite.scala b/resource-managers/kubernetes/core/volcano/src/test/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStepSuite.scala similarity index 100% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStepSuite.scala rename to resource-managers/kubernetes/core/volcano/src/test/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStepSuite.scala diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 518c5bc217071..45ce25b8e037a 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -46,7 +46,6 @@ org.apache.spark.deploy.k8s.integrationtest.YuniKornTag - **/*Volcano*.scala jar Spark Project Kubernetes Integration Tests @@ -83,19 +82,6 @@ - - - - net.alchim31.maven - scala-maven-plugin - - - ${volcano.exclude} - - - - - org.codehaus.mojo @@ -219,9 +205,6 @@ volcano - - - io.fabric8 @@ -229,6 +212,28 @@ ${kubernetes-client.version} + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-volcano-test-sources + generate-test-sources + + add-test-source + + + + volcano/src/test/scala + + + + + + + diff --git 
a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoSuite.scala b/resource-managers/kubernetes/integration-tests/volcano/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoSuite.scala similarity index 100% rename from resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoSuite.scala rename to resource-managers/kubernetes/integration-tests/volcano/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoSuite.scala diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala b/resource-managers/kubernetes/integration-tests/volcano/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala similarity index 100% rename from resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala rename to resource-managers/kubernetes/integration-tests/volcano/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index acfc0011f5d05..de28041acd41f 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -146,6 +146,7 @@ BUCKETS: 'BUCKETS'; BY: 'BY'; BYTE: 'BYTE'; CACHE: 'CACHE'; +CALL: 'CALL'; CALLED: 'CALLED'; CASCADE: 'CASCADE'; CASE: 'CASE'; @@ -255,12 +256,14 @@ BINARY_HEX: 'X'; HOUR: 'HOUR'; HOURS: 'HOURS'; IDENTIFIER_KW: 'IDENTIFIER'; +IDENTITY: 'IDENTITY'; IF: 'IF'; IGNORE: 'IGNORE'; IMMEDIATE: 'IMMEDIATE'; IMPORT: 'IMPORT'; IN: 'IN'; INCLUDE: 'INCLUDE'; +INCREMENT: 'INCREMENT'; INDEX: 'INDEX'; INDEXES: 'INDEXES'; INNER: 'INNER'; @@ -276,6 +279,7 @@ INTO: 'INTO'; INVOKER: 'INVOKER'; IS: 'IS'; ITEMS: 'ITEMS'; +ITERATE: 'ITERATE'; JOIN: 'JOIN'; KEYS: 'KEYS'; LANGUAGE: 'LANGUAGE'; @@ -283,6 +287,7 @@ LAST: 'LAST'; LATERAL: 'LATERAL'; LAZY: 'LAZY'; LEADING: 'LEADING'; +LEAVE: 'LEAVE'; LEFT: 'LEFT'; LIKE: 'LIKE'; ILIKE: 'ILIKE'; @@ -362,6 +367,7 @@ REFERENCES: 'REFERENCES'; REFRESH: 'REFRESH'; RENAME: 'RENAME'; REPAIR: 'REPAIR'; +REPEAT: 'REPEAT'; REPEATABLE: 'REPEATABLE'; REPLACE: 'REPLACE'; RESET: 'RESET'; @@ -451,6 +457,7 @@ UNKNOWN: 'UNKNOWN'; UNLOCK: 'UNLOCK'; UNPIVOT: 'UNPIVOT'; UNSET: 'UNSET'; +UNTIL: 'UNTIL'; UPDATE: 'UPDATE'; USE: 'USE'; USER: 'USER'; @@ -501,6 +508,7 @@ TILDE: '~'; AMPERSAND: '&'; PIPE: '|'; CONCAT_PIPE: '||'; +OPERATOR_PIPE: '|>'; HAT: '^'; COLON: ':'; DOUBLE_COLON: '::'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 5b8805821b045..094f7f5315b80 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -64,7 +64,11 @@ compoundStatement | setStatementWithOptionalVarKeyword | beginEndCompoundBlock | ifElseStatement + | caseStatement | whileStatement + | repeatStatement + | leaveStatement + | iterateStatement ; setStatementWithOptionalVarKeyword @@ -83,6 +87,25 @@ ifElseStatement (ELSE elseBody=compoundBody)? END IF ; +repeatStatement + : beginLabel? 
REPEAT compoundBody UNTIL booleanExpression END REPEAT endLabel? + ; + +leaveStatement + : LEAVE multipartIdentifier + ; + +iterateStatement + : ITERATE multipartIdentifier + ; + +caseStatement + : CASE (WHEN conditions+=booleanExpression THEN conditionalBodies+=compoundBody)+ + (ELSE elseBody=compoundBody)? END CASE #searchedCaseStatement + | CASE caseVariable=expression (WHEN conditionExpressions+=expression THEN conditionalBodies+=compoundBody)+ + (ELSE elseBody=compoundBody)? END CASE #simpleCaseStatement + ; + singleStatement : (statement|setResetStatement) SEMICOLON* EOF ; @@ -275,6 +298,10 @@ statement LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN (OPTIONS options=propertyList)? #createIndex | DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex + | CALL identifierReference + LEFT_PAREN + (functionArgument (COMMA functionArgument)*)? + RIGHT_PAREN #call | unsupportedHiveNativeCommands .*? #failNativeCommand ; @@ -589,6 +616,7 @@ queryTerm operator=INTERSECT setQuantifier? right=queryTerm #setOperation | left=queryTerm {!legacy_setops_precedence_enabled}? operator=(UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + | left=queryTerm OPERATOR_PIPE operatorPipeRightSide #operatorPipeStatement ; queryPrimary @@ -1272,7 +1300,22 @@ colDefinitionOption ; generationExpression - : GENERATED ALWAYS AS LEFT_PAREN expression RIGHT_PAREN + : GENERATED ALWAYS AS LEFT_PAREN expression RIGHT_PAREN #generatedColumn + | GENERATED (ALWAYS | BY DEFAULT) AS IDENTITY identityColSpec? #identityColumn + ; + +identityColSpec + : LEFT_PAREN sequenceGeneratorOption* RIGHT_PAREN + ; + +sequenceGeneratorOption + : START WITH start=sequenceGeneratorStartOrStep + | INCREMENT BY step=sequenceGeneratorStartOrStep + ; + +sequenceGeneratorStartOrStep + : MINUS? INTEGER_VALUE + | MINUS? BIGINT_LITERAL ; complexColTypeList @@ -1447,6 +1490,11 @@ version | stringLit ; +operatorPipeRightSide + : selectClause + | whereClause + ; + // When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. 
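Taken together, the lexer and parser changes above admit several new surface syntaxes: `CALL` statements, identity columns, the `|>` pipe operator, and the scripting constructs (`REPEAT ... UNTIL ... END REPEAT`, `LEAVE`, `ITERATE`, `CASE`), which are compound-statement forms used inside `BEGIN ... END` scripts rather than standalone statements. The sketch below is illustrative only: the catalog, table, and procedure names are invented, and each statement still depends on the corresponding analyzer/runtime support (and, for identity columns and `CALL`, on a catalog that implements them).

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# identityColumn: GENERATED ALWAYS (or BY DEFAULT) AS IDENTITY, with optional
# START WITH / INCREMENT BY options.
spark.sql("""
    CREATE TABLE demo_cat.db.events (
        id BIGINT GENERATED ALWAYS AS IDENTITY (START WITH 1 INCREMENT BY 1),
        payload STRING
    )
""")

# call: CALL <multipart identifier> ( <arguments> ); the procedure name is invented.
spark.sql("CALL demo_cat.system.some_procedure('db.events')")

# operatorPipeStatement: queryTerm |> selectClause or |> whereClause
# (the two right-hand sides added in this change).
spark.sql("FROM demo_cat.db.events |> WHERE id > 10 |> SELECT payload")
```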
// - Reserved keywords: // Keywords that are reserved and can't be used as identifiers for table, view, column, @@ -1562,11 +1610,13 @@ ansiNonReserved | HOUR | HOURS | IDENTIFIER_KW + | IDENTITY | IF | IGNORE | IMMEDIATE | IMPORT | INCLUDE + | INCREMENT | INDEX | INDEXES | INPATH @@ -1578,10 +1628,12 @@ ansiNonReserved | INTERVAL | INVOKER | ITEMS + | ITERATE | KEYS | LANGUAGE | LAST | LAZY + | LEAVE | LIKE | ILIKE | LIMIT @@ -1648,6 +1700,7 @@ ansiNonReserved | REFRESH | RENAME | REPAIR + | REPEAT | REPEATABLE | REPLACE | RESET @@ -1723,6 +1776,7 @@ ansiNonReserved | UNLOCK | UNPIVOT | UNSET + | UNTIL | UPDATE | USE | VALUES @@ -1802,6 +1856,7 @@ nonReserved | BY | BYTE | CACHE + | CALL | CALLED | CASCADE | CASE @@ -1908,12 +1963,14 @@ nonReserved | HOUR | HOURS | IDENTIFIER_KW + | IDENTITY | IF | IGNORE | IMMEDIATE | IMPORT | IN | INCLUDE + | INCREMENT | INDEX | INDEXES | INPATH @@ -1927,11 +1984,13 @@ nonReserved | INVOKER | IS | ITEMS + | ITERATE | KEYS | LANGUAGE | LAST | LAZY | LEADING + | LEAVE | LIKE | LONG | ILIKE @@ -2009,6 +2068,7 @@ nonReserved | REFRESH | RENAME | REPAIR + | REPEAT | REPEATABLE | REPLACE | RESET @@ -2093,6 +2153,7 @@ nonReserved | UNLOCK | UNPIVOT | UNSET + | UNTIL | UPDATE | USE | USER diff --git a/sql/api/src/main/java/org/apache/spark/sql/connector/catalog/IdentityColumnSpec.java b/sql/api/src/main/java/org/apache/spark/sql/connector/catalog/IdentityColumnSpec.java new file mode 100644 index 0000000000000..4a8943736bd31 --- /dev/null +++ b/sql/api/src/main/java/org/apache/spark/sql/connector/catalog/IdentityColumnSpec.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog; +import org.apache.spark.annotation.Evolving; + +import java.util.Objects; + +/** + * Identity column specification. + */ +@Evolving +public class IdentityColumnSpec { + private final long start; + private final long step; + private final boolean allowExplicitInsert; + + /** + * Creates an identity column specification. 
+ * @param start the start value to generate the identity values + * @param step the step value to generate the identity values + * @param allowExplicitInsert whether the identity column allows explicit insertion of values + */ + public IdentityColumnSpec(long start, long step, boolean allowExplicitInsert) { + this.start = start; + this.step = step; + this.allowExplicitInsert = allowExplicitInsert; + } + + /** + * @return the start value to generate the identity values + */ + public long getStart() { + return start; + } + + /** + * @return the step value to generate the identity values + */ + public long getStep() { + return step; + } + + /** + * @return whether the identity column allows explicit insertion of values + */ + public boolean isAllowExplicitInsert() { + return allowExplicitInsert; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IdentityColumnSpec that = (IdentityColumnSpec) o; + return start == that.start && + step == that.step && + allowExplicitInsert == that.allowExplicitInsert; + } + + @Override + public int hashCode() { + return Objects.hash(start, step, allowExplicitInsert); + } + + @Override + public String toString() { + return "IdentityColumnSpec{" + + "start=" + start + + ", step=" + step + + ", allowExplicitInsert=" + allowExplicitInsert + + "}"; + } +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/api/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 7a428f6cc3288..a2c1f2cc41f8f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} * @since 1.3.0 */ @Stable -class AnalysisException protected( +class AnalysisException protected ( val message: String, val line: Option[Int] = None, val startPosition: Option[Int] = None, @@ -37,12 +37,12 @@ class AnalysisException protected( val errorClass: Option[String] = None, val messageParameters: Map[String, String] = Map.empty, val context: Array[QueryContext] = Array.empty) - extends Exception(message, cause.orNull) with SparkThrowable with Serializable with WithOrigin { + extends Exception(message, cause.orNull) + with SparkThrowable + with Serializable + with WithOrigin { - def this( - errorClass: String, - messageParameters: Map[String, String], - cause: Option[Throwable]) = + def this(errorClass: String, messageParameters: Map[String, String], cause: Option[Throwable]) = this( SparkThrowableHelper.getMessage(errorClass, messageParameters), errorClass = Some(errorClass), @@ -73,18 +73,10 @@ class AnalysisException protected( cause = null, context = context) - def this( - errorClass: String, - messageParameters: Map[String, String]) = - this( - errorClass = errorClass, - messageParameters = messageParameters, - cause = None) + def this(errorClass: String, messageParameters: Map[String, String]) = + this(errorClass = errorClass, messageParameters = messageParameters, cause = None) - def this( - errorClass: String, - messageParameters: Map[String, String], - origin: Origin) = + def this(errorClass: String, messageParameters: Map[String, String], origin: Origin) = this( SparkThrowableHelper.getMessage(errorClass, messageParameters), line = origin.line, @@ -115,8 +107,14 @@ class AnalysisException protected( errorClass: Option[String] = this.errorClass, messageParameters: Map[String, String] = 
this.messageParameters, context: Array[QueryContext] = this.context): AnalysisException = - new AnalysisException(message, line, startPosition, cause, errorClass, - messageParameters, context) + new AnalysisException( + message, + line, + startPosition, + cause, + errorClass, + messageParameters, + context) def withPosition(origin: Origin): AnalysisException = { val newException = this.copy( diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Artifact.scala b/sql/api/src/main/scala/org/apache/spark/sql/Artifact.scala index c78280af6e021..7e020df06fe47 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/Artifact.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/Artifact.scala @@ -28,8 +28,7 @@ import org.apache.spark.sql.util.ArtifactUtils import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.MavenUtils - -private[sql] class Artifact private(val path: Path, val storage: LocalData) { +private[sql] class Artifact private (val path: Path, val storage: LocalData) { require(!path.isAbsolute, s"Bad path: $path") lazy val size: Long = storage match { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala index cd6a04b2a0562..31ce44eca1684 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala @@ -72,34 +72,34 @@ private[spark] object Column { isDistinct: Boolean, isInternal: Boolean, inputs: Seq[Column]): Column = withOrigin { - Column(internal.UnresolvedFunction( - name, - inputs.map(_.node), - isDistinct = isDistinct, - isInternal = isInternal)) + Column( + internal.UnresolvedFunction( + name, + inputs.map(_.node), + isDistinct = isDistinct, + isInternal = isInternal)) } } /** - * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. - * To create a [[TypedColumn]], use the `as` function on a [[Column]]. + * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. To + * create a [[TypedColumn]], use the `as` function on a [[Column]]. * - * @tparam T The input type expected for this expression. Can be `Any` if the expression is type - * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). - * @tparam U The output type of this column. + * @tparam T + * The input type expected for this expression. Can be `Any` if the expression is type checked + * by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). + * @tparam U + * The output type of this column. * * @since 1.6.0 */ @Stable -class TypedColumn[-T, U]( - node: ColumnNode, - private[sql] val encoder: Encoder[U]) - extends Column(node) { +class TypedColumn[-T, U](node: ColumnNode, private[sql] val encoder: Encoder[U]) + extends Column(node) { /** - * Gives the [[TypedColumn]] a name (alias). - * If the current `TypedColumn` has metadata associated with it, this metadata will be propagated - * to the new column. + * Gives the [[TypedColumn]] a name (alias). If the current `TypedColumn` has metadata + * associated with it, this metadata will be propagated to the new column. * * @group expr_ops * @since 2.0.0 @@ -168,23 +168,20 @@ class Column(val node: ColumnNode) extends Logging { override def hashCode: Int = this.node.normalized.hashCode() /** - * Provides a type hint about the expected return value of this column. This information can - * be used by operations such as `select` on a [[Dataset]] to automatically convert the - * results into the correct JVM types. 
+ * Provides a type hint about the expected return value of this column. This information can be + * used by operations such as `select` on a [[Dataset]] to automatically convert the results + * into the correct JVM types. * @since 1.6.0 */ - def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](node, implicitly[Encoder[U]]) + def as[U: Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](node, implicitly[Encoder[U]]) /** - * Extracts a value or values from a complex type. - * The following types of extraction are supported: - *
<ul>
- * <li>Given an Array, an integer ordinal can be used to retrieve a single value.</li>
- * <li>Given a Map, a key of the correct type can be used to retrieve an individual value.</li>
- * <li>Given a Struct, a string fieldName can be used to extract that field.</li>
- * <li>Given an Array of Structs, a string fieldName can be used to extract filed
- * of every struct in that array, and return an Array of fields.</li>
- * </ul>
+ * Extracts a value or values from a complex type. The following types of extraction are
+ * supported: <ul> <li>Given an Array, an integer ordinal can be used to retrieve a single
+ * value.</li> <li>Given a Map, a key of the correct type can be used to retrieve an individual
+ * value.</li> <li>Given a Struct, a string fieldName can be used to extract that field.</li>
+ * <li>Given an Array of Structs, a string fieldName can be used to extract filed of every
+ * struct in that array, and return an Array of fields.</li> </ul>
* @group expr_ops * @since 1.4.0 */ @@ -283,8 +280,8 @@ class Column(val node: ColumnNode) extends Logging { * * @group expr_ops * @since 2.0.0 - */ - def =!= (other: Any): Column = !(this === other) + */ + def =!=(other: Any): Column = !(this === other) /** * Inequality test. @@ -300,9 +297,9 @@ class Column(val node: ColumnNode) extends Logging { * * @group expr_ops * @since 1.3.0 - */ + */ @deprecated("!== does not have the same precedence as ===, use =!= instead", "2.0.0") - def !== (other: Any): Column = this =!= other + def !==(other: Any): Column = this =!= other /** * Inequality test. @@ -464,8 +461,8 @@ class Column(val node: ColumnNode) extends Logging { def eqNullSafe(other: Any): Column = this <=> other /** - * Evaluates a list of conditions and returns one of multiple possible result expressions. - * If otherwise is not defined at the end, null is returned for unmatched conditions. + * Evaluates a list of conditions and returns one of multiple possible result expressions. If + * otherwise is not defined at the end, null is returned for unmatched conditions. * * {{{ * // Example: encoding gender string column into integer. @@ -489,8 +486,7 @@ class Column(val node: ColumnNode) extends Logging { case internal.CaseWhenOtherwise(branches, None, _) => internal.CaseWhenOtherwise(branches :+ ((condition.node, lit(value).node)), None) case internal.CaseWhenOtherwise(_, Some(_), _) => - throw new IllegalArgumentException( - "when() cannot be applied once otherwise() is applied") + throw new IllegalArgumentException("when() cannot be applied once otherwise() is applied") case _ => throw new IllegalArgumentException( "when() can only be applied on a Column previously generated by when() function") @@ -498,8 +494,8 @@ class Column(val node: ColumnNode) extends Logging { } /** - * Evaluates a list of conditions and returns one of multiple possible result expressions. - * If otherwise is not defined at the end, null is returned for unmatched conditions. + * Evaluates a list of conditions and returns one of multiple possible result expressions. If + * otherwise is not defined at the end, null is returned for unmatched conditions. * * {{{ * // Example: encoding gender string column into integer. @@ -765,13 +761,11 @@ class Column(val node: ColumnNode) extends Logging { * A boolean expression that is evaluated to true if the value of this expression is contained * by the evaluated values of the arguments. * - * Note: Since the type of the elements in the list are inferred only during the run time, - * the elements will be "up-casted" to the most common type for comparison. - * For eg: - * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the - * comparison will look like "String vs String". - * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the - * comparison will look like "Double vs Double" + * Note: Since the type of the elements in the list are inferred only during the run time, the + * elements will be "up-casted" to the most common type for comparison. For eg: 1) In the case + * of "Int vs String", the "Int" will be up-casted to "String" and the comparison will look like + * "String vs String". 2) In the case of "Float vs Double", the "Float" will be up-casted to + * "Double" and the comparison will look like "Double vs Double" * * @group expr_ops * @since 1.5.0 @@ -784,12 +778,10 @@ class Column(val node: ColumnNode) extends Logging { * by the provided collection. 
* * Note: Since the type of the elements in the collection are inferred only during the run time, - * the elements will be "up-casted" to the most common type for comparison. - * For eg: - * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the - * comparison will look like "String vs String". - * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the - * comparison will look like "Double vs Double" + * the elements will be "up-casted" to the most common type for comparison. For eg: 1) In the + * case of "Int vs String", the "Int" will be up-casted to "String" and the comparison will look + * like "String vs String". 2) In the case of "Float vs Double", the "Float" will be up-casted + * to "Double" and the comparison will look like "Double vs Double" * * @group expr_ops * @since 2.4.0 @@ -801,12 +793,10 @@ class Column(val node: ColumnNode) extends Logging { * by the provided collection. * * Note: Since the type of the elements in the collection are inferred only during the run time, - * the elements will be "up-casted" to the most common type for comparison. - * For eg: - * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the - * comparison will look like "String vs String". - * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the - * comparison will look like "Double vs Double" + * the elements will be "up-casted" to the most common type for comparison. For eg: 1) In the + * case of "Int vs String", the "Int" will be up-casted to "String" and the comparison will look + * like "String vs String". 2) In the case of "Float vs Double", the "Float" will be up-casted + * to "Double" and the comparison will look like "Double vs Double" * * @group java_expr_ops * @since 2.4.0 @@ -822,8 +812,7 @@ class Column(val node: ColumnNode) extends Logging { def like(literal: String): Column = fn("like", literal) /** - * SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex - * match. + * SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex match. * * @group expr_ops * @since 1.3.0 @@ -839,8 +828,8 @@ class Column(val node: ColumnNode) extends Logging { def ilike(literal: String): Column = fn("ilike", literal) /** - * An expression that gets an item at position `ordinal` out of an array, - * or gets a value by key `key` in a `MapType`. + * An expression that gets an item at position `ordinal` out of an array, or gets a value by key + * `key` in a `MapType`. * * @group expr_ops * @since 1.3.0 @@ -885,8 +874,8 @@ class Column(val node: ColumnNode) extends Logging { * // result: {"a":{"a":1,"b":2,"c":3,"d":4}} * }}} * - * However, if you are going to add/replace multiple nested fields, it is more optimal to extract - * out the nested struct before adding/replacing multiple fields e.g. + * However, if you are going to add/replace multiple nested fields, it is more optimal to + * extract out the nested struct before adding/replacing multiple fields e.g. * * {{{ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col") @@ -906,8 +895,8 @@ class Column(val node: ColumnNode) extends Logging { // scalastyle:off line.size.limit /** - * An expression that drops fields in `StructType` by name. - * This is a no-op if schema doesn't contain field name(s). + * An expression that drops fields in `StructType` by name. This is a no-op if schema doesn't + * contain field name(s). 
* * {{{ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col") @@ -951,8 +940,8 @@ class Column(val node: ColumnNode) extends Logging { * // result: {"a":{"a":1}} * }}} * - * However, if you are going to drop multiple nested fields, it is more optimal to extract - * out the nested struct before dropping multiple fields from it e.g. + * However, if you are going to drop multiple nested fields, it is more optimal to extract out + * the nested struct before dropping multiple fields from it e.g. * * {{{ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col") @@ -980,8 +969,10 @@ class Column(val node: ColumnNode) extends Logging { /** * An expression that returns a substring. - * @param startPos expression for the starting position. - * @param len expression for the length of the substring. + * @param startPos + * expression for the starting position. + * @param len + * expression for the length of the substring. * * @group expr_ops * @since 1.3.0 @@ -990,8 +981,10 @@ class Column(val node: ColumnNode) extends Logging { /** * An expression that returns a substring. - * @param startPos starting position. - * @param len length of the substring. + * @param startPos + * starting position. + * @param len + * length of the substring. * * @group expr_ops * @since 1.3.0 @@ -1057,9 +1050,9 @@ class Column(val node: ColumnNode) extends Logging { * df.select($"colA".as("colB")) * }}} * - * If the current column has metadata associated with it, this metadata will be propagated - * to the new column. If this not desired, use the API `as(alias: String, metadata: Metadata)` - * with explicit metadata. + * If the current column has metadata associated with it, this metadata will be propagated to + * the new column. If this not desired, use the API `as(alias: String, metadata: Metadata)` with + * explicit metadata. * * @group expr_ops * @since 1.3.0 @@ -1097,9 +1090,9 @@ class Column(val node: ColumnNode) extends Logging { * df.select($"colA".as("colB")) * }}} * - * If the current column has metadata associated with it, this metadata will be propagated - * to the new column. If this not desired, use the API `as(alias: String, metadata: Metadata)` - * with explicit metadata. + * If the current column has metadata associated with it, this metadata will be propagated to + * the new column. If this not desired, use the API `as(alias: String, metadata: Metadata)` with + * explicit metadata. * * @group expr_ops * @since 1.3.0 @@ -1126,9 +1119,9 @@ class Column(val node: ColumnNode) extends Logging { * df.select($"colA".name("colB")) * }}} * - * If the current column has metadata associated with it, this metadata will be propagated - * to the new column. If this not desired, use the API `as(alias: String, metadata: Metadata)` - * with explicit metadata. + * If the current column has metadata associated with it, this metadata will be propagated to + * the new column. If this not desired, use the API `as(alias: String, metadata: Metadata)` with + * explicit metadata. * * @group expr_ops * @since 2.0.0 @@ -1152,9 +1145,9 @@ class Column(val node: ColumnNode) extends Logging { def cast(to: DataType): Column = Column(internal.Cast(node, to)) /** - * Casts the column to a different data type, using the canonical string representation - * of the type. The supported types are: `string`, `boolean`, `byte`, `short`, `int`, `long`, - * `float`, `double`, `decimal`, `date`, `timestamp`. + * Casts the column to a different data type, using the canonical string representation of the + * type. 
The supported types are: `string`, `boolean`, `byte`, `short`, `int`, `long`, `float`, + * `double`, `decimal`, `date`, `timestamp`. * {{{ * // Casts colA to integer. * df.select(df("colA").cast("int")) @@ -1224,8 +1217,8 @@ class Column(val node: ColumnNode) extends Logging { def desc: Column = desc_nulls_last /** - * Returns a sort expression based on the descending order of the column, - * and null values appear before non-null values. + * Returns a sort expression based on the descending order of the column, and null values appear + * before non-null values. * {{{ * // Scala: sort a DataFrame by age column in descending order and null values appearing first. * df.sort(df("age").desc_nulls_first) @@ -1237,13 +1230,12 @@ class Column(val node: ColumnNode) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_first: Column = sortOrder( - internal.SortOrder.Descending, - internal.SortOrder.NullsFirst) + def desc_nulls_first: Column = + sortOrder(internal.SortOrder.Descending, internal.SortOrder.NullsFirst) /** - * Returns a sort expression based on the descending order of the column, - * and null values appear after non-null values. + * Returns a sort expression based on the descending order of the column, and null values appear + * after non-null values. * {{{ * // Scala: sort a DataFrame by age column in descending order and null values appearing last. * df.sort(df("age").desc_nulls_last) @@ -1255,9 +1247,8 @@ class Column(val node: ColumnNode) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_last: Column = sortOrder( - internal.SortOrder.Descending, - internal.SortOrder.NullsLast) + def desc_nulls_last: Column = + sortOrder(internal.SortOrder.Descending, internal.SortOrder.NullsLast) /** * Returns a sort expression based on ascending order of the column. @@ -1275,8 +1266,8 @@ class Column(val node: ColumnNode) extends Logging { def asc: Column = asc_nulls_first /** - * Returns a sort expression based on ascending order of the column, - * and null values return before non-null values. + * Returns a sort expression based on ascending order of the column, and null values return + * before non-null values. * {{{ * // Scala: sort a DataFrame by age column in ascending order and null values appearing first. * df.sort(df("age").asc_nulls_first) @@ -1288,13 +1279,12 @@ class Column(val node: ColumnNode) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_first: Column = sortOrder( - internal.SortOrder.Ascending, - internal.SortOrder.NullsFirst) + def asc_nulls_first: Column = + sortOrder(internal.SortOrder.Ascending, internal.SortOrder.NullsFirst) /** - * Returns a sort expression based on ascending order of the column, - * and null values appear after non-null values. + * Returns a sort expression based on ascending order of the column, and null values appear + * after non-null values. * {{{ * // Scala: sort a DataFrame by age column in ascending order and null values appearing last. * df.sort(df("age").asc_nulls_last) @@ -1306,9 +1296,8 @@ class Column(val node: ColumnNode) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_last: Column = sortOrder( - internal.SortOrder.Ascending, - internal.SortOrder.NullsLast) + def asc_nulls_last: Column = + sortOrder(internal.SortOrder.Ascending, internal.SortOrder.NullsLast) /** * Prints the expression to the console for debugging purposes. @@ -1378,8 +1367,8 @@ class Column(val node: ColumnNode) extends Logging { } /** - * Defines an empty analytic clause. 
In this case the analytic function is applied - * and presented for all rows in the result set. + * Defines an empty analytic clause. In this case the analytic function is applied and presented + * for all rows in the result set. * * {{{ * df.select( @@ -1395,7 +1384,6 @@ class Column(val node: ColumnNode) extends Logging { } - /** * A convenient class used for constructing schema. * diff --git a/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 96855ee5ad164..1838c6bc8468f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -34,15 +34,11 @@ import org.apache.spark.sql.errors.CompilationErrors abstract class DataFrameWriter[T] { /** - * Specifies the behavior when data or table already exists. Options include: - *
<ul>
- * <li>`SaveMode.Overwrite`: overwrite the existing data.</li>
- * <li>`SaveMode.Append`: append the data.</li>
- * <li>`SaveMode.Ignore`: ignore the operation (i.e. no-op).</li>
- * <li>`SaveMode.ErrorIfExists`: throw an exception at runtime.</li>
- * </ul>
- * The default option is `ErrorIfExists`.
+ * Specifies the behavior when data or table already exists. Options include: <ul>
+ * <li>`SaveMode.Overwrite`: overwrite the existing data.</li> <li>`SaveMode.Append`: append the
+ * data.</li> <li>`SaveMode.Ignore`: ignore the operation (i.e. no-op).</li>
+ * <li>`SaveMode.ErrorIfExists`: throw an exception at runtime.</li> </ul>
The default + * option is `ErrorIfExists`. * * @since 1.4.0 */ @@ -52,13 +48,10 @@ abstract class DataFrameWriter[T] { } /** - * Specifies the behavior when data or table already exists. Options include:
- * <ul>
- * <li>`overwrite`: overwrite the existing data.</li>
- * <li>`append`: append the data.</li>
- * <li>`ignore`: ignore the operation (i.e. no-op).</li>
- * <li>`error` or `errorifexists`: default option, throw an exception at runtime.</li>
- * </ul>
+ * Specifies the behavior when data or table already exists. Options include: <ul>
+ * <li>`overwrite`: overwrite the existing data.</li> <li>`append`: append the data.</li>
+ * <li>`ignore`: ignore the operation (i.e. no-op).</li> <li>`error` or `errorifexists`: default
+ * option, throw an exception at runtime.</li> </ul>
* * @since 1.4.0 */ @@ -85,8 +78,8 @@ abstract class DataFrameWriter[T] { /** * Adds an output option for the underlying data source. * - * All options are maintained in a case-insensitive way in terms of key names. - * If a new option has the same key case-insensitively, it will override the existing option. + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. * * @since 1.4.0 */ @@ -98,8 +91,8 @@ abstract class DataFrameWriter[T] { /** * Adds an output option for the underlying data source. * - * All options are maintained in a case-insensitive way in terms of key names. - * If a new option has the same key case-insensitively, it will override the existing option. + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. * * @since 2.0.0 */ @@ -108,8 +101,8 @@ abstract class DataFrameWriter[T] { /** * Adds an output option for the underlying data source. * - * All options are maintained in a case-insensitive way in terms of key names. - * If a new option has the same key case-insensitively, it will override the existing option. + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. * * @since 2.0.0 */ @@ -118,8 +111,8 @@ abstract class DataFrameWriter[T] { /** * Adds an output option for the underlying data source. * - * All options are maintained in a case-insensitive way in terms of key names. - * If a new option has the same key case-insensitively, it will override the existing option. + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. * * @since 2.0.0 */ @@ -128,8 +121,8 @@ abstract class DataFrameWriter[T] { /** * (Scala-specific) Adds output options for the underlying data source. * - * All options are maintained in a case-insensitive way in terms of key names. - * If a new option has the same key case-insensitively, it will override the existing option. + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. * * @since 1.4.0 */ @@ -141,8 +134,8 @@ abstract class DataFrameWriter[T] { /** * Adds output options for the underlying data source. * - * All options are maintained in a case-insensitive way in terms of key names. - * If a new option has the same key case-insensitively, it will override the existing option. + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. * * @since 1.4.0 */ @@ -154,16 +147,13 @@ abstract class DataFrameWriter[T] { /** * Partitions the output by the given columns on the file system. If specified, the output is * laid out on the file system similar to Hive's partitioning scheme. As an example, when we - * partition a dataset by year and then month, the directory layout would look like: - *
<ul>
- * <li>year=2016/month=01/</li>
- * <li>year=2016/month=02/</li>
- * </ul>
+ * partition a dataset by year and then month, the directory layout would look like: <ul>
+ * <li>year=2016/month=01/</li> <li>year=2016/month=02/</li> </ul>
* - * Partitioning is one of the most widely used techniques to optimize physical data layout. - * It provides a coarse-grained index for skipping unnecessary data reads when queries have - * predicates on the partitioned columns. In order for partitioning to work well, the number - * of distinct values in each column should typically be less than tens of thousands. + * Partitioning is one of the most widely used techniques to optimize physical data layout. It + * provides a coarse-grained index for skipping unnecessary data reads when queries have + * predicates on the partitioned columns. In order for partitioning to work well, the number of + * distinct values in each column should typically be less than tens of thousands. * * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark * 2.1.0. @@ -179,8 +169,8 @@ abstract class DataFrameWriter[T] { /** * Buckets the output by the given columns. If specified, the output is laid out on the file - * system similar to Hive's bucketing scheme, but with a different bucket hash function - * and is not compatible with Hive's bucketing. + * system similar to Hive's bucketing scheme, but with a different bucket hash function and is + * not compatible with Hive's bucketing. * * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark * 2.1.0. @@ -241,13 +231,15 @@ abstract class DataFrameWriter[T] { def save(): Unit /** - * Inserts the content of the `DataFrame` to the specified table. It requires that - * the schema of the `DataFrame` is the same as the schema of the table. + * Inserts the content of the `DataFrame` to the specified table. It requires that the schema of + * the `DataFrame` is the same as the schema of the table. * - * @note Unlike `saveAsTable`, `insertInto` ignores the column names and just uses position-based - * resolution. For example: - * @note SaveMode.ErrorIfExists and SaveMode.Ignore behave as SaveMode.Append in `insertInto` as - * `insertInto` is not a table creating operation. + * @note + * Unlike `saveAsTable`, `insertInto` ignores the column names and just uses position-based + * resolution. For example: + * @note + * SaveMode.ErrorIfExists and SaveMode.Ignore behave as SaveMode.Append in `insertInto` as + * `insertInto` is not a table creating operation. * * {{{ * scala> Seq((1, 2)).toDF("i", "j").write.mode("overwrite").saveAsTable("t1") @@ -263,7 +255,7 @@ abstract class DataFrameWriter[T] { * +---+---+ * }}} * - * Because it inserts data to an existing table, format or options will be ignored. + * Because it inserts data to an existing table, format or options will be ignored. * @since 1.4.0 */ def insertInto(tableName: String): Unit @@ -271,15 +263,15 @@ abstract class DataFrameWriter[T] { /** * Saves the content of the `DataFrame` as the specified table. * - * In the case the table already exists, behavior of this function depends on the - * save mode, specified by the `mode` function (default to throwing an exception). - * When `mode` is `Overwrite`, the schema of the `DataFrame` does not need to be - * the same as that of the existing table. + * In the case the table already exists, behavior of this function depends on the save mode, + * specified by the `mode` function (default to throwing an exception). When `mode` is + * `Overwrite`, the schema of the `DataFrame` does not need to be the same as that of the + * existing table. 
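A hedged sketch of the writer behaviour documented in this hunk (save modes, partitioned layout, and the by-name vs. by-position difference between `saveAsTable` and `insertInto`); it assumes a SparkSession `spark` with a writable warehouse, and the path and table names are placeholders:

{{{
  import org.apache.spark.sql.SaveMode
  import spark.implicits._

  val events = Seq((2016, 1, "a"), (2016, 2, "b")).toDF("year", "month", "value")

  // Partitioned, file-based output; the default mode is ErrorIfExists.
  events.write
    .mode(SaveMode.Overwrite)
    .partitionBy("year", "month")
    .parquet("/tmp/events_parquet")

  // saveAsTable resolves columns by name, insertInto purely by position.
  events.write.mode("overwrite").saveAsTable("events_tbl")
  events.select($"year", $"month", $"value").write.insertInto("events_tbl")
}}}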
* * When `mode` is `Append`, if there is an existing table, we will use the format and options of * the existing table. The column order in the schema of the `DataFrame` doesn't need to be same - * as that of the existing table. Unlike `insertInto`, `saveAsTable` will use the column names to - * find the correct column positions. For example: + * as that of the existing table. Unlike `insertInto`, `saveAsTable` will use the column names + * to find the correct column positions. For example: * * {{{ * scala> Seq((1, 2)).toDF("i", "j").write.mode("overwrite").saveAsTable("t1") @@ -293,10 +285,10 @@ abstract class DataFrameWriter[T] { * +---+---+ * }}} * - * In this method, save mode is used to determine the behavior if the data source table exists in - * Spark catalog. We will always overwrite the underlying data of data source (e.g. a table in - * JDBC data source) if the table doesn't exist in Spark catalog, and will always append to the - * underlying data of data source if the table already exists. + * In this method, save mode is used to determine the behavior if the data source table exists + * in Spark catalog. We will always overwrite the underlying data of data source (e.g. a table + * in JDBC data source) if the table doesn't exist in Spark catalog, and will always append to + * the underlying data of data source if the table already exists. * * When the DataFrame is created from a non-partitioned `HadoopFsRelation` with a single input * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC @@ -310,25 +302,25 @@ abstract class DataFrameWriter[T] { /** * Saves the content of the `DataFrame` to an external database table via JDBC. In the case the - * table already exists in the external database, behavior of this function depends on the - * save mode, specified by the `mode` function (default to throwing an exception). + * table already exists in the external database, behavior of this function depends on the save + * mode, specified by the `mode` function (default to throwing an exception). * * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash * your external database systems. * - * JDBC-specific option and parameter documentation for storing tables via JDBC in - *
- * Data Source Option in the version you use. - * - * @param table Name of the table in the external database. - * @param connectionProperties JDBC database connection arguments, a list of arbitrary string - * tag/value. Normally at least a "user" and "password" property - * should be included. "batchsize" can be used to control the - * number of rows per insert. "isolationLevel" can be one of - * "NONE", "READ_COMMITTED", "READ_UNCOMMITTED", "REPEATABLE_READ", - * or "SERIALIZABLE", corresponding to standard transaction - * isolation levels defined by JDBC's Connection object, with default - * of "READ_UNCOMMITTED". + * JDBC-specific option and parameter documentation for storing tables via JDBC in + * Data Source Option in the version you use. + * + * @param table + * Name of the table in the external database. + * @param connectionProperties + * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least + * a "user" and "password" property should be included. "batchsize" can be used to control the + * number of rows per insert. "isolationLevel" can be one of "NONE", "READ_COMMITTED", + * "READ_UNCOMMITTED", "REPEATABLE_READ", or "SERIALIZABLE", corresponding to standard + * transaction isolation levels defined by JDBC's Connection object, with default of + * "READ_UNCOMMITTED". * @since 1.4.0 */ def jdbc(url: String, table: String, connectionProperties: util.Properties): Unit = { @@ -343,16 +335,16 @@ abstract class DataFrameWriter[T] { } /** - * Saves the content of the `DataFrame` in JSON format ( - * JSON Lines text format or newline-delimited JSON) at the specified path. - * This is equivalent to: + * Saves the content of the `DataFrame` in JSON format ( JSON + * Lines text format or newline-delimited JSON) at the specified path. This is equivalent + * to: * {{{ * format("json").save(path) * }}} * - * You can find the JSON-specific options for writing JSON files in - * - * Data Source Option in the version you use. + * You can find the JSON-specific options for writing JSON files in + * Data Source Option in the version you use. * * @since 1.4.0 */ @@ -361,16 +353,15 @@ abstract class DataFrameWriter[T] { } /** - * Saves the content of the `DataFrame` in Parquet format at the specified path. - * This is equivalent to: + * Saves the content of the `DataFrame` in Parquet format at the specified path. This is + * equivalent to: * {{{ * format("parquet").save(path) * }}} * - * Parquet-specific option(s) for writing Parquet files can be found in - * - * Data Source Option in the version you use. + * Parquet-specific option(s) for writing Parquet files can be found in Data + * Source Option in the version you use. * * @since 1.4.0 */ @@ -379,16 +370,15 @@ abstract class DataFrameWriter[T] { } /** - * Saves the content of the `DataFrame` in ORC format at the specified path. - * This is equivalent to: + * Saves the content of the `DataFrame` in ORC format at the specified path. This is equivalent + * to: * {{{ * format("orc").save(path) * }}} * - * ORC-specific option(s) for writing ORC files can be found in - * - * Data Source Option in the version you use. + * ORC-specific option(s) for writing ORC files can be found in Data + * Source Option in the version you use. * * @since 1.5.0 */ @@ -397,9 +387,9 @@ abstract class DataFrameWriter[T] { } /** - * Saves the content of the `DataFrame` in a text file at the specified path. - * The DataFrame must have only one column that is of string type. - * Each row becomes a new line in the output file. 
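A short sketch of the format-specific writers and the JDBC writer covered here; `df` is an assumed existing DataFrame, and the JDBC URL, table name and credentials are placeholders only:

{{{
  import java.util.Properties

  df.write.mode("overwrite").json("/tmp/out_json")
  df.write.mode("overwrite").orc("/tmp/out_orc")

  val props = new Properties()
  props.setProperty("user", "spark_user")                // placeholder credentials
  props.setProperty("password", "secret")
  props.setProperty("isolationLevel", "READ_COMMITTED")  // one of the documented isolation levels
  df.write.mode("append").jdbc("jdbc:postgresql://localhost/db", "public.target_table", props)
}}}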
For example: + * Saves the content of the `DataFrame` in a text file at the specified path. The DataFrame must + * have only one column that is of string type. Each row becomes a new line in the output file. + * For example: * {{{ * // Scala: * df.write.text("/path/to/output") @@ -409,9 +399,9 @@ abstract class DataFrameWriter[T] { * }}} * The text files will be encoded as UTF-8. * - * You can find the text-specific options for writing text files in - * - * Data Source Option in the version you use. + * You can find the text-specific options for writing text files in + * Data Source Option in the version you use. * * @since 1.6.0 */ @@ -420,15 +410,15 @@ abstract class DataFrameWriter[T] { } /** - * Saves the content of the `DataFrame` in CSV format at the specified path. - * This is equivalent to: + * Saves the content of the `DataFrame` in CSV format at the specified path. This is equivalent + * to: * {{{ * format("csv").save(path) * }}} * - * You can find the CSV-specific options for writing CSV files in - * - * Data Source Option in the version you use. + * You can find the CSV-specific options for writing CSV files in + * Data Source Option in the version you use. * * @since 2.0.0 */ @@ -437,31 +427,25 @@ abstract class DataFrameWriter[T] { } /** - * Saves the content of the `DataFrame` in XML format at the specified path. - * This is equivalent to: + * Saves the content of the `DataFrame` in XML format at the specified path. This is equivalent + * to: * {{{ * format("xml").save(path) * }}} * - * Note that writing a XML file from `DataFrame` having a field `ArrayType` with - * its element as `ArrayType` would have an additional nested field for the element. - * For example, the `DataFrame` having a field below, + * Note that writing a XML file from `DataFrame` having a field `ArrayType` with its element as + * `ArrayType` would have an additional nested field for the element. For example, the + * `DataFrame` having a field below, * - * {@code fieldA [[data1], [data2]]} + * {@code fieldA [[data1], [data2]]} * - * would produce a XML file below. - * {@code - * - * data1 - * - * - * data2 - * } + * would produce a XML file below. {@code data1 + * data2 } * * Namely, roundtrip in writing and reading can end up in different schema structure. * - * You can find the XML-specific options for writing XML files in - * + * You can find the XML-specific options for writing XML files in * Data Source Option in the version you use. */ def xml(path: String): Unit = { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala similarity index 62% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala rename to sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index 3f9b224003914..37a29c2e4b66d 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -14,111 +14,80 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - package org.apache.spark.sql -import scala.collection.mutable -import scala.jdk.CollectionConverters._ +import java.util import org.apache.spark.annotation.Experimental -import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} /** - * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 + * Interface used to write a [[org.apache.spark.sql.api.Dataset]] to external storage using the v2 * API. * - * @since 3.4.0 + * @since 3.0.0 */ @Experimental -final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T]) - extends CreateTableWriter[T] { - import ds.sparkSession.RichColumn - - private var provider: Option[String] = None - - private val options = new mutable.HashMap[String, String]() +abstract class DataFrameWriterV2[T] extends CreateTableWriter[T] { - private val properties = new mutable.HashMap[String, String]() + /** @inheritdoc */ + override def using(provider: String): this.type - private var partitioning: Option[Seq[proto.Expression]] = None + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = option(key, value.toString) - private var clustering: Option[Seq[String]] = None + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = option(key, value.toString) - private var overwriteCondition: Option[proto.Expression] = None + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = option(key, value.toString) - override def using(provider: String): CreateTableWriter[T] = { - this.provider = Some(provider) - this - } + /** @inheritdoc */ + override def option(key: String, value: String): this.type - override def option(key: String, value: String): DataFrameWriterV2[T] = { - this.options.put(key, value) - this - } + /** @inheritdoc */ + override def options(options: scala.collection.Map[String, String]): this.type - override def options(options: scala.collection.Map[String, String]): DataFrameWriterV2[T] = { - options.foreach { case (key, value) => - this.options.put(key, value) - } - this - } + /** @inheritdoc */ + override def options(options: util.Map[String, String]): this.type - override def options(options: java.util.Map[String, String]): DataFrameWriterV2[T] = { - this.options(options.asScala) - this - } - - override def tableProperty(property: String, value: String): CreateTableWriter[T] = { - this.properties.put(property, value) - this - } + /** @inheritdoc */ + override def tableProperty(property: String, value: String): this.type + /** @inheritdoc */ @scala.annotation.varargs - override def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] = { - val asTransforms = (column +: columns).map(_.expr) - this.partitioning = Some(asTransforms) - this - } + override def partitionedBy(column: Column, columns: Column*): this.type + /** @inheritdoc */ @scala.annotation.varargs - override def clusterBy(colName: String, colNames: String*): CreateTableWriter[T] = { - this.clustering = Some(colName +: colNames) - this - } - - override def create(): Unit = { - executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE) - } - - override def replace(): Unit = { - executeWriteOperation(proto.WriteOperationV2.Mode.MODE_REPLACE) - } - - override def createOrReplace(): Unit = { - executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE) - } + override def clusterBy(colName: String, colNames: String*): 
this.type /** * Append the contents of the data frame to the output table. * - * If the output table does not exist, this operation will fail. The data frame will be + * If the output table does not exist, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be * validated to ensure it is compatible with the existing table. + * + * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException + * If the table does not exist */ - def append(): Unit = { - executeWriteOperation(proto.WriteOperationV2.Mode.MODE_APPEND) - } + @throws(classOf[NoSuchTableException]) + def append(): Unit /** * Overwrite rows matching the given filter condition with the contents of the data frame in the * output table. * - * If the output table does not exist, this operation will fail. The data frame will be + * If the output table does not exist, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be * validated to ensure it is compatible with the existing table. + * + * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException + * If the table does not exist */ - def overwrite(condition: Column): Unit = { - overwriteCondition = Some(condition.expr) - executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE) - } + @throws(classOf[NoSuchTableException]) + def overwrite(condition: Column): Unit /** * Overwrite all partition for which the data frame contains at least one row with the contents @@ -127,85 +96,64 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T]) * This operation is equivalent to Hive's `INSERT OVERWRITE ... PARTITION`, which replaces * partitions dynamically depending on the contents of the data frame. * - * If the output table does not exist, this operation will fail. The data frame will be + * If the output table does not exist, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be * validated to ensure it is compatible with the existing table. + * + * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException + * If the table does not exist */ - def overwritePartitions(): Unit = { - executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS) - } - - private def executeWriteOperation(mode: proto.WriteOperationV2.Mode): Unit = { - val builder = proto.WriteOperationV2.newBuilder() - - builder.setInput(ds.plan.getRoot) - builder.setTableName(table) - provider.foreach(builder.setProvider) - - partitioning.foreach(columns => builder.addAllPartitioningColumns(columns.asJava)) - clustering.foreach(columns => builder.addAllClusteringColumns(columns.asJava)) - - options.foreach { case (k, v) => - builder.putOptions(k, v) - } - properties.foreach { case (k, v) => - builder.putTableProperties(k, v) - } - - builder.setMode(mode) - - overwriteCondition.foreach(builder.setOverwriteCondition) - - ds.sparkSession.execute(proto.Command.newBuilder().setWriteOperationV2(builder).build()) - } + @throws(classOf[NoSuchTableException]) + def overwritePartitions(): Unit } /** * Configuration methods common to create/replace operations and insert/overwrite operations. * @tparam R * builder type to return - * @since 3.4.0 + * @since 3.0.0 */ trait WriteConfigMethods[R] { /** * Add a write option. * - * @since 3.4.0 + * @since 3.0.0 */ def option(key: String, value: String): R /** * Add a boolean output option. 
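The v2 writer whose operations are declared above is normally reached through `Dataset.writeTo`; a minimal sketch under that assumption, with a placeholder catalog table name and an assumed DataFrame `df`:

{{{
  import org.apache.spark.sql.functions.col

  // Create the table, failing with TableAlreadyExistsException if it is already there.
  df.writeTo("catalog.db.events").using("parquet").partitionedBy(col("year")).create()

  // Append to an existing table; fails with NoSuchTableException if it is missing.
  df.writeTo("catalog.db.events").append()

  // Overwrite only the rows matching a condition.
  df.writeTo("catalog.db.events").overwrite(col("year") === 2016)
}}}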
* - * @since 3.4.0 + * @since 3.0.0 */ def option(key: String, value: Boolean): R = option(key, value.toString) /** * Add a long output option. * - * @since 3.4.0 + * @since 3.0.0 */ def option(key: String, value: Long): R = option(key, value.toString) /** * Add a double output option. * - * @since 3.4.0 + * @since 3.0.0 */ def option(key: String, value: Double): R = option(key, value.toString) /** * Add write options from a Scala Map. * - * @since 3.4.0 + * @since 3.0.0 */ def options(options: scala.collection.Map[String, String]): R /** * Add write options from a Java Map. * - * @since 3.4.0 + * @since 3.0.0 */ def options(options: java.util.Map[String, String]): R } @@ -213,7 +161,7 @@ trait WriteConfigMethods[R] { /** * Trait to restrict calls to create and replace operations. * - * @since 3.4.0 + * @since 3.0.0 */ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { @@ -223,8 +171,13 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { * The new table's schema, partition layout, properties, and other configuration will be based * on the configuration set on this writer. * - * If the output table exists, this operation will fail. + * If the output table exists, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException]]. + * + * @throws org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException + * If the table already exists */ + @throws(classOf[TableAlreadyExistsException]) def create(): Unit /** @@ -233,8 +186,13 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { * The existing table's schema, partition layout, properties, and other configuration will be * replaced with the contents of the data frame and the configuration set on this writer. * - * If the output table does not exist, this operation will fail. + * If the output table does not exist, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException]]. + * + * @throws org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException + * If the table does not exist */ + @throws(classOf[CannotReplaceMissingTableException]) def replace(): Unit /** @@ -260,7 +218,7 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { * predicates on the partitioned columns. In order for partitioning to work well, the number of * distinct values in each column should typically be less than tens of thousands. * - * @since 3.4.0 + * @since 3.0.0 */ @scala.annotation.varargs def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] @@ -282,7 +240,7 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { * Specifies a provider for the underlying output data source. Spark's default catalog supports * "parquet", "json", etc. * - * @since 3.4.0 + * @since 3.0.0 */ def using(provider: String): CreateTableWriter[T] diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/Encoder.scala index ea760d80541c8..d125e89b8c410 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types._ /** * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. 
* - * == Scala == + * ==Scala== * Encoders are generally created automatically through implicits from a `SparkSession`, or can be * explicitly created by calling static methods on [[Encoders]]. * @@ -35,7 +35,7 @@ import org.apache.spark.sql.types._ * val ds = Seq(1, 2, 3).toDS() // implicitly provided (spark.implicits.newIntEncoder) * }}} * - * == Java == + * ==Java== * Encoders are specified by calling static methods on [[Encoders]]. * * {{{ @@ -57,8 +57,8 @@ import org.apache.spark.sql.types._ * Encoders.bean(MyClass.class); * }}} * - * == Implementation == - * - Encoders should be thread-safe. + * ==Implementation== + * - Encoders should be thread-safe. * * @since 1.6.0 */ diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala similarity index 72% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala rename to sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala index ffd9975770066..9976b34f7a01f 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -14,95 +14,99 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql -import scala.reflect.ClassTag +import java.lang.reflect.Modifier + +import scala.reflect.{classTag, ClassTag} import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, JavaSerializationCodec, RowEncoder => RowEncoderFactory} +import org.apache.spark.sql.catalyst.encoders.{Codec, JavaSerializationCodec, KryoSerializationCodec, RowEncoder => SchemaInference} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.errors.ExecutionErrors +import org.apache.spark.sql.types._ /** * Methods for creating an [[Encoder]]. * - * @since 3.5.0 + * @since 1.6.0 */ object Encoders { /** * An encoder for nullable boolean type. The Scala primitive encoder is available as * [[scalaBoolean]]. - * @since 3.5.0 + * @since 1.6.0 */ def BOOLEAN: Encoder[java.lang.Boolean] = BoxedBooleanEncoder /** * An encoder for nullable byte type. The Scala primitive encoder is available as [[scalaByte]]. - * @since 3.5.0 + * @since 1.6.0 */ def BYTE: Encoder[java.lang.Byte] = BoxedByteEncoder /** * An encoder for nullable short type. The Scala primitive encoder is available as * [[scalaShort]]. - * @since 3.5.0 + * @since 1.6.0 */ def SHORT: Encoder[java.lang.Short] = BoxedShortEncoder /** * An encoder for nullable int type. The Scala primitive encoder is available as [[scalaInt]]. - * @since 3.5.0 + * @since 1.6.0 */ def INT: Encoder[java.lang.Integer] = BoxedIntEncoder /** * An encoder for nullable long type. The Scala primitive encoder is available as [[scalaLong]]. - * @since 3.5.0 + * @since 1.6.0 */ def LONG: Encoder[java.lang.Long] = BoxedLongEncoder /** * An encoder for nullable float type. The Scala primitive encoder is available as * [[scalaFloat]]. - * @since 3.5.0 + * @since 1.6.0 */ def FLOAT: Encoder[java.lang.Float] = BoxedFloatEncoder /** * An encoder for nullable double type. The Scala primitive encoder is available as * [[scalaDouble]]. 
- * @since 3.5.0 + * @since 1.6.0 */ def DOUBLE: Encoder[java.lang.Double] = BoxedDoubleEncoder /** * An encoder for nullable string type. * - * @since 3.5.0 + * @since 1.6.0 */ def STRING: Encoder[java.lang.String] = StringEncoder /** * An encoder for nullable decimal type. * - * @since 3.5.0 + * @since 1.6.0 */ def DECIMAL: Encoder[java.math.BigDecimal] = DEFAULT_JAVA_DECIMAL_ENCODER /** * An encoder for nullable date type. * - * @since 3.5.0 + * @since 1.6.0 */ - def DATE: Encoder[java.sql.Date] = DateEncoder(lenientSerialization = false) + def DATE: Encoder[java.sql.Date] = STRICT_DATE_ENCODER /** * Creates an encoder that serializes instances of the `java.time.LocalDate` class to the * internal representation of nullable Catalyst's DateType. * - * @since 3.5.0 + * @since 3.0.0 */ def LOCALDATE: Encoder[java.time.LocalDate] = STRICT_LOCAL_DATE_ENCODER @@ -110,14 +114,14 @@ object Encoders { * Creates an encoder that serializes instances of the `java.time.LocalDateTime` class to the * internal representation of nullable Catalyst's TimestampNTZType. * - * @since 3.5.0 + * @since 3.4.0 */ def LOCALDATETIME: Encoder[java.time.LocalDateTime] = LocalDateTimeEncoder /** * An encoder for nullable timestamp type. * - * @since 3.5.0 + * @since 1.6.0 */ def TIMESTAMP: Encoder[java.sql.Timestamp] = STRICT_TIMESTAMP_ENCODER @@ -125,14 +129,14 @@ object Encoders { * Creates an encoder that serializes instances of the `java.time.Instant` class to the internal * representation of nullable Catalyst's TimestampType. * - * @since 3.5.0 + * @since 3.0.0 */ def INSTANT: Encoder[java.time.Instant] = STRICT_INSTANT_ENCODER /** * An encoder for arrays of bytes. * - * @since 3.5.0 + * @since 1.6.1 */ def BINARY: Encoder[Array[Byte]] = BinaryEncoder @@ -140,7 +144,7 @@ object Encoders { * Creates an encoder that serializes instances of the `java.time.Duration` class to the * internal representation of nullable Catalyst's DayTimeIntervalType. * - * @since 3.5.0 + * @since 3.2.0 */ def DURATION: Encoder[java.time.Duration] = DayTimeIntervalEncoder @@ -148,7 +152,7 @@ object Encoders { * Creates an encoder that serializes instances of the `java.time.Period` class to the internal * representation of nullable Catalyst's YearMonthIntervalType. * - * @since 3.5.0 + * @since 3.2.0 */ def PERIOD: Encoder[java.time.Period] = YearMonthIntervalEncoder @@ -166,7 +170,7 @@ object Encoders { * - collection types: array, java.util.List, and map * - nested java bean. * - * @since 3.5.0 + * @since 1.6.0 */ def bean[T](beanClass: Class[T]): Encoder[T] = JavaTypeInference.encoderFor(beanClass) @@ -175,7 +179,27 @@ object Encoders { * * @since 3.5.0 */ - def row(schema: StructType): Encoder[Row] = RowEncoderFactory.encoderFor(schema) + def row(schema: StructType): Encoder[Row] = SchemaInference.encoderFor(schema) + + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. This + * encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def kryo[T: ClassTag]: Encoder[T] = genericSerializer(KryoSerializationCodec) + + /** + * Creates an encoder that serializes objects of type T using Kryo. This encoder maps T into a + * single byte array (binary) field. + * + * T must be publicly accessible. 
+ * + * @since 1.6.0 + */ + def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) /** * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java @@ -185,11 +209,10 @@ object Encoders { * * @note * This is extremely inefficient and should only be used as the last resort. - * @since 4.0.0 + * + * @since 1.6.0 */ - def javaSerialization[T: ClassTag]: Encoder[T] = { - TransformingEncoder(implicitly[ClassTag[T]], BinaryEncoder, JavaSerializationCodec) - } + def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(JavaSerializationCodec) /** * Creates an encoder that serializes objects of type T using generic Java serialization. This @@ -199,25 +222,53 @@ object Encoders { * * @note * This is extremely inefficient and should only be used as the last resort. - * @since 4.0.0 + * + * @since 1.6.0 */ - def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) + def javaSerialization[T](clazz: Class[T]): Encoder[T] = + javaSerialization(ClassTag[T](clazz)) + + /** Throws an exception if T is not a public class. */ + private def validatePublicClass[T: ClassTag](): Unit = { + if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) { + throw ExecutionErrors.notPublicClassError(classTag[T].runtimeClass.getName) + } + } + + /** A way to construct encoders using generic serializers. */ + private def genericSerializer[T: ClassTag]( + provider: () => Codec[Any, Array[Byte]]): Encoder[T] = { + if (classTag[T].runtimeClass.isPrimitive) { + throw ExecutionErrors.primitiveTypesNotSupportedError() + } + + validatePublicClass[T]() + + TransformingEncoder(classTag[T], BinaryEncoder, provider) + } - private def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = { - ProductEncoder.tuple(encoders.asInstanceOf[Seq[AgnosticEncoder[_]]]).asInstanceOf[Encoder[T]] + private[sql] def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = { + ProductEncoder.tuple(encoders.map(agnosticEncoderFor(_))).asInstanceOf[Encoder[T]] } + /** + * An encoder for 1-ary tuples. + * + * @since 4.0.0 + */ + def tuple[T1](e1: Encoder[T1]): Encoder[(T1)] = tupleEncoder(e1) + /** * An encoder for 2-ary tuples. * - * @since 3.5.0 + * @since 1.6.0 */ def tuple[T1, T2](e1: Encoder[T1], e2: Encoder[T2]): Encoder[(T1, T2)] = tupleEncoder(e1, e2) /** * An encoder for 3-ary tuples. * - * @since 3.5.0 + * @since 1.6.0 */ def tuple[T1, T2, T3]( e1: Encoder[T1], @@ -227,7 +278,7 @@ object Encoders { /** * An encoder for 4-ary tuples. * - * @since 3.5.0 + * @since 1.6.0 */ def tuple[T1, T2, T3, T4]( e1: Encoder[T1], @@ -238,7 +289,7 @@ object Encoders { /** * An encoder for 5-ary tuples. * - * @since 3.5.0 + * @since 1.6.0 */ def tuple[T1, T2, T3, T4, T5]( e1: Encoder[T1], @@ -249,49 +300,50 @@ object Encoders { /** * An encoder for Scala's product type (tuples, case classes, etc). - * @since 3.5.0 + * @since 2.0.0 */ def product[T <: Product: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T] /** * An encoder for Scala's primitive int type. - * @since 3.5.0 + * @since 2.0.0 */ def scalaInt: Encoder[Int] = PrimitiveIntEncoder /** * An encoder for Scala's primitive long type. - * @since 3.5.0 + * @since 2.0.0 */ def scalaLong: Encoder[Long] = PrimitiveLongEncoder /** * An encoder for Scala's primitive double type. - * @since 3.5.0 + * @since 2.0.0 */ def scalaDouble: Encoder[Double] = PrimitiveDoubleEncoder /** * An encoder for Scala's primitive float type. 
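A usage sketch of the encoder factories in this file, assuming the class involved is public (as the Kryo and Java-serialization encoders require) and that a SparkSession `spark` is available:

{{{
  import org.apache.spark.sql.{Encoder, Encoders}

  case class Event(id: Long, name: String)

  // Product, tuple and primitive encoders.
  val eventEncoder: Encoder[Event] = Encoders.product[Event]
  val pairEncoder: Encoder[(String, Int)] = Encoders.tuple(Encoders.STRING, Encoders.scalaInt)

  // Generic serializers map T to a single binary column.
  val kryoEncoder: Encoder[Event] = Encoders.kryo[Event]
  val javaEncoder: Encoder[Event] = Encoders.javaSerialization[Event]

  val ds = spark.createDataset(Seq(Event(1L, "a")))(eventEncoder)
}}}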
- * @since 3.5.0 + * @since 2.0.0 */ def scalaFloat: Encoder[Float] = PrimitiveFloatEncoder /** * An encoder for Scala's primitive byte type. - * @since 3.5.0 + * @since 2.0.0 */ def scalaByte: Encoder[Byte] = PrimitiveByteEncoder /** * An encoder for Scala's primitive short type. - * @since 3.5.0 + * @since 2.0.0 */ def scalaShort: Encoder[Short] = PrimitiveShortEncoder /** * An encoder for Scala's primitive boolean type. - * @since 3.5.0 + * @since 2.0.0 */ def scalaBoolean: Encoder[Boolean] = PrimitiveBooleanEncoder + } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/api/src/main/scala/org/apache/spark/sql/ForeachWriter.scala similarity index 99% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/ForeachWriter.scala rename to sql/api/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 756356f7f0282..4e2bb35146a63 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -89,7 +89,7 @@ package org.apache.spark.sql * }); * }}} * - * @since 3.5.0 + * @since 2.0.0 */ abstract class ForeachWriter[T] extends Serializable { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala b/sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala similarity index 65% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala rename to sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala index 71813af1e354f..db56b39e28aeb 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala @@ -14,46 +14,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.sql -import scala.jdk.CollectionConverters._ - -import org.apache.spark.SparkRuntimeException import org.apache.spark.annotation.Experimental -import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.{Expression, MergeIntoTableCommand} -import org.apache.spark.connect.proto.MergeAction -import org.apache.spark.sql.functions.expr /** * `MergeIntoWriter` provides methods to define and execute merge actions based on specified * conditions. * + * Please note that schema evolution is disabled by default. + * * @tparam T * the type of data in the Dataset. - * @param table - * the name of the target table for the merge operation. - * @param ds - * the source Dataset to merge into the target table. - * @param on - * the merge condition. - * @param schemaEvolutionEnabled - * whether to enable automatic schema evolution for this merge operation. Default is `false`. - * * @since 4.0.0 */ @Experimental -class MergeIntoWriter[T] private[sql] ( - table: String, - ds: Dataset[T], - on: Column, - schemaEvolutionEnabled: Boolean = false) { - import ds.sparkSession.RichColumn +abstract class MergeIntoWriter[T] { + private var schemaEvolution: Boolean = false - private[sql] var matchedActions: Seq[MergeAction] = Seq.empty[MergeAction] - private[sql] var notMatchedActions: Seq[MergeAction] = Seq.empty[MergeAction] - private[sql] var notMatchedBySourceActions: Seq[MergeAction] = Seq.empty[MergeAction] + private[sql] def schemaEvolutionEnabled: Boolean = schemaEvolution /** * Initialize a `WhenMatched` action without any condition. 
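The relocated ForeachWriter keeps its open/process/close contract; a minimal streaming-sink sketch, assuming a streaming DataFrame `streamingDf` already exists:

{{{
  import org.apache.spark.sql.{ForeachWriter, Row}

  val writer = new ForeachWriter[Row] {
    def open(partitionId: Long, epochId: Long): Boolean = true  // e.g. open a connection
    def process(value: Row): Unit = println(value)              // send the row somewhere
    def close(errorOrNull: Throwable): Unit = ()                // release resources
  }

  val query = streamingDf.writeStream.foreach(writer).start()
}}}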
@@ -176,84 +155,39 @@ class MergeIntoWriter[T] private[sql] ( /** * Enable automatic schema evolution for this merge operation. + * * @return * A `MergeIntoWriter` instance with schema evolution enabled. */ def withSchemaEvolution(): MergeIntoWriter[T] = { - new MergeIntoWriter[T](this.table, this.ds, this.on, schemaEvolutionEnabled = true) - .withNewMatchedActions(this.matchedActions: _*) - .withNewNotMatchedActions(this.notMatchedActions: _*) - .withNewNotMatchedBySourceActions(this.notMatchedBySourceActions: _*) + schemaEvolution = true + this } /** * Executes the merge operation. */ - def merge(): Unit = { - if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) { - throw new SparkRuntimeException( - errorClass = "NO_MERGE_ACTION_SPECIFIED", - messageParameters = Map.empty) - } + def merge(): Unit - val matchedActionExpressions = - matchedActions.map(Expression.newBuilder().setMergeAction(_)).map(_.build()) - val notMatchedActionExpressions = - notMatchedActions.map(Expression.newBuilder().setMergeAction(_)).map(_.build()) - val notMatchedBySourceActionExpressions = - notMatchedBySourceActions.map(Expression.newBuilder().setMergeAction(_)).map(_.build()) - val mergeIntoCommand = MergeIntoTableCommand - .newBuilder() - .setTargetTableName(table) - .setSourceTablePlan(ds.plan.getRoot) - .setMergeCondition(on.expr) - .addAllMatchActions(matchedActionExpressions.asJava) - .addAllNotMatchedActions(notMatchedActionExpressions.asJava) - .addAllNotMatchedBySourceActions(notMatchedBySourceActionExpressions.asJava) - .setWithSchemaEvolution(schemaEvolutionEnabled) - .build() + // Action callbacks. + protected[sql] def insertAll(condition: Option[Column]): MergeIntoWriter[T] - ds.sparkSession.execute( - proto.Command - .newBuilder() - .setMergeIntoTableCommand(mergeIntoCommand) - .build()) - } + protected[sql] def insert( + condition: Option[Column], + map: Map[String, Column]): MergeIntoWriter[T] - private[sql] def withNewMatchedActions(action: MergeAction*): MergeIntoWriter[T] = { - this.matchedActions = this.matchedActions :++ action - this - } + protected[sql] def updateAll( + condition: Option[Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] - private[sql] def withNewNotMatchedActions(action: MergeAction*): MergeIntoWriter[T] = { - this.notMatchedActions = this.notMatchedActions :++ action - this - } + protected[sql] def update( + condition: Option[Column], + map: Map[String, Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] - private[sql] def withNewNotMatchedBySourceActions(action: MergeAction*): MergeIntoWriter[T] = { - this.notMatchedBySourceActions = this.notMatchedBySourceActions :++ action - this - } - - private[sql] def buildMergeAction( - actionType: MergeAction.ActionType, - conditionOpt: Option[Column], - assignmentsOpt: Option[Map[String, Column]] = None): MergeAction = { - val assignmentsProtoOpt = assignmentsOpt.map { - _.map { case (k, v) => - MergeAction.Assignment - .newBuilder() - .setKey(expr(k).expr) - .setValue(v.expr) - .build() - }.toSeq.asJava - } - - val builder = MergeAction.newBuilder().setActionType(actionType) - conditionOpt.map(c => builder.setCondition(c.expr)) - assignmentsProtoOpt.map(builder.addAllAssignments) - builder.build() - } + protected[sql] def delete( + condition: Option[Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] } /** @@ -265,7 +199,6 @@ class MergeIntoWriter[T] private[sql] ( * @param condition * An optional condition Expression that specifies when the actions should be 
applied. If the * condition is None, the actions will be applied to all matched rows. - * * @tparam T * The type of data in the MergeIntoWriter. */ @@ -279,11 +212,8 @@ case class WhenMatched[T] private[sql] ( * @return * The MergeIntoWriter instance with the update all action configured. */ - def updateAll(): MergeIntoWriter[T] = { - mergeIntoWriter.withNewMatchedActions( - mergeIntoWriter - .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_UPDATE_STAR, condition)) - } + def updateAll(): MergeIntoWriter[T] = + mergeIntoWriter.updateAll(condition, notMatchedBySource = false) /** * Specifies an action to update matched rows in the DataFrame with the provided column @@ -294,11 +224,8 @@ case class WhenMatched[T] private[sql] ( * @return * The MergeIntoWriter instance with the update action configured. */ - def update(map: Map[String, Column]): MergeIntoWriter[T] = { - mergeIntoWriter.withNewMatchedActions( - mergeIntoWriter - .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_UPDATE, condition, Some(map))) - } + def update(map: Map[String, Column]): MergeIntoWriter[T] = + mergeIntoWriter.update(condition, map, notMatchedBySource = false) /** * Specifies an action to delete matched rows from the DataFrame. @@ -306,10 +233,8 @@ case class WhenMatched[T] private[sql] ( * @return * The MergeIntoWriter instance with the delete action configured. */ - def delete(): MergeIntoWriter[T] = { - mergeIntoWriter.withNewMatchedActions( - mergeIntoWriter.buildMergeAction(MergeAction.ActionType.ACTION_TYPE_DELETE, condition)) - } + def delete(): MergeIntoWriter[T] = + mergeIntoWriter.delete(condition, notMatchedBySource = false) } /** @@ -335,11 +260,8 @@ case class WhenNotMatched[T] private[sql] ( * @return * The MergeIntoWriter instance with the insert all action configured. */ - def insertAll(): MergeIntoWriter[T] = { - mergeIntoWriter.withNewNotMatchedActions( - mergeIntoWriter - .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_INSERT_STAR, condition)) - } + def insertAll(): MergeIntoWriter[T] = + mergeIntoWriter.insertAll(condition) /** * Specifies an action to insert non-matched rows into the DataFrame with the provided column @@ -350,11 +272,8 @@ case class WhenNotMatched[T] private[sql] ( * @return * The MergeIntoWriter instance with the insert action configured. */ - def insert(map: Map[String, Column]): MergeIntoWriter[T] = { - mergeIntoWriter.withNewNotMatchedActions( - mergeIntoWriter - .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_INSERT, condition, Some(map))) - } + def insert(map: Map[String, Column]): MergeIntoWriter[T] = + mergeIntoWriter.insert(condition, map) } /** @@ -379,11 +298,8 @@ case class WhenNotMatchedBySource[T] private[sql] ( * @return * The MergeIntoWriter instance with the update all action configured. */ - def updateAll(): MergeIntoWriter[T] = { - mergeIntoWriter.withNewNotMatchedBySourceActions( - mergeIntoWriter - .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_UPDATE_STAR, condition)) - } + def updateAll(): MergeIntoWriter[T] = + mergeIntoWriter.updateAll(condition, notMatchedBySource = true) /** * Specifies an action to update non-matched rows in the target DataFrame with the provided @@ -394,11 +310,8 @@ case class WhenNotMatchedBySource[T] private[sql] ( * @return * The MergeIntoWriter instance with the update action configured. 
*/ - def update(map: Map[String, Column]): MergeIntoWriter[T] = { - mergeIntoWriter.withNewNotMatchedBySourceActions( - mergeIntoWriter - .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_UPDATE, condition, Some(map))) - } + def update(map: Map[String, Column]): MergeIntoWriter[T] = + mergeIntoWriter.update(condition, map, notMatchedBySource = true) /** * Specifies an action to delete non-matched rows from the target DataFrame when not matched by @@ -407,9 +320,6 @@ case class WhenNotMatchedBySource[T] private[sql] ( * @return * The MergeIntoWriter instance with the delete action configured. */ - def delete(): MergeIntoWriter[T] = { - mergeIntoWriter.withNewNotMatchedBySourceActions( - mergeIntoWriter - .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_DELETE, condition)) - } + def delete(): MergeIntoWriter[T] = + mergeIntoWriter.delete(condition, notMatchedBySource = true) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala index 02f5a8de1e3f6..fa427fe651907 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala @@ -38,13 +38,14 @@ import org.apache.spark.util.SparkThreadUtils * val metrics = observation.get * }}} * - * This collects the metrics while the first action is executed on the observed dataset. Subsequent - * actions do not modify the metrics returned by [[get]]. Retrieval of the metric via [[get]] - * blocks until the first action has finished and metrics become available. + * This collects the metrics while the first action is executed on the observed dataset. + * Subsequent actions do not modify the metrics returned by [[get]]. Retrieval of the metric via + * [[get]] blocks until the first action has finished and metrics become available. * * This class does not support streaming datasets. * - * @param name name of the metric + * @param name + * name of the metric * @since 3.3.0 */ class Observation(val name: String) { @@ -65,23 +66,27 @@ class Observation(val name: String) { val future: Future[Map[String, Any]] = promise.future /** - * (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish - * its first action. Only the result of the first action is available. Subsequent actions do not + * (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish its + * first action. Only the result of the first action is available. Subsequent actions do not * modify the result. * - * @return the observed metrics as a `Map[String, Any]` - * @throws InterruptedException interrupted while waiting + * @return + * the observed metrics as a `Map[String, Any]` + * @throws InterruptedException + * interrupted while waiting */ @throws[InterruptedException] def get: Map[String, Any] = SparkThreadUtils.awaitResult(future, Duration.Inf) /** - * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish - * its first action. Only the result of the first action is available. Subsequent actions do not + * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish its + * first action. Only the result of the first action is available. Subsequent actions do not * modify the result. 
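A hedged sketch of the merge builder whose action callbacks are declared above, typically obtained through `Dataset.mergeInto`; the source DataFrame `updatesDf`, the target table and the join columns are placeholders:

{{{
  import org.apache.spark.sql.functions.col

  updatesDf
    .mergeInto("catalog.db.target", col("source.id") === col("target.id"))
    .whenMatched()
    .updateAll()
    .whenNotMatched()
    .insertAll()
    .whenNotMatchedBySource()
    .delete()
    .withSchemaEvolution()
    .merge()
}}}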
* - * @return the observed metrics as a `java.util.Map[String, Object]` - * @throws InterruptedException interrupted while waiting + * @return + * the observed metrics as a `java.util.Map[String, Object]` + * @throws InterruptedException + * interrupted while waiting */ @throws[InterruptedException] def getAsJava: java.util.Map[String, Any] = get.asJava @@ -89,7 +94,8 @@ class Observation(val name: String) { /** * Get the observed metrics. This returns the metrics if they are available, otherwise an empty. * - * @return the observed metrics as a `Map[String, Any]` + * @return + * the observed metrics as a `Map[String, Any]` */ @throws[InterruptedException] private[sql] def getOrEmpty: Map[String, Any] = { @@ -108,7 +114,8 @@ class Observation(val name: String) { /** * Set the observed metrics and notify all waiting threads to resume. * - * @return `true` if all waiting threads were notified, `false` if otherwise. + * @return + * `true` if all waiting threads were notified, `false` if otherwise. */ private[sql] def setMetricsAndNotify(metrics: Row): Boolean = { val metricsMap = metrics.getValuesMap(metrics.schema.map(_.name)) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Row.scala b/sql/api/src/main/scala/org/apache/spark/sql/Row.scala index fb4b4a6f37c8d..aa14115453aea 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/Row.scala @@ -45,6 +45,7 @@ import org.apache.spark.util.ArrayImplicits._ */ @Stable object Row { + /** * This method can be used to extract fields from a [[Row]] object in a pattern match. Example: * {{{ @@ -83,9 +84,8 @@ object Row { val empty = apply() } - /** - * Represents one row of output from a relational operator. Allows both generic access by ordinal, + * Represents one row of output from a relational operator. Allows both generic access by ordinal, * which will incur boxing overhead for primitives, as well as native primitive access. * * It is invalid to use the native primitive interface to retrieve a value that is null, instead a @@ -103,9 +103,9 @@ object Row { * Row.fromSeq(Seq(value1, value2, ...)) * }}} * - * A value of a row can be accessed through both generic access by ordinal, - * which will incur boxing overhead for primitives, as well as native primitive access. - * An example of generic access by ordinal: + * A value of a row can be accessed through both generic access by ordinal, which will incur + * boxing overhead for primitives, as well as native primitive access. An example of generic + * access by ordinal: * {{{ * import org.apache.spark.sql._ * @@ -117,10 +117,9 @@ object Row { * // fourthValue: Any = null * }}} * - * For native primitive access, it is invalid to use the native primitive interface to retrieve - * a value that is null, instead a user must check `isNullAt` before attempting to retrieve a - * value that might be null. - * An example of native primitive access: + * For native primitive access, it is invalid to use the native primitive interface to retrieve a + * value that is null, instead a user must check `isNullAt` before attempting to retrieve a value + * that might be null. An example of native primitive access: * {{{ * // using the row from the previous example. * val firstValue = row.getInt(0) @@ -143,6 +142,7 @@ object Row { */ @Stable trait Row extends Serializable { + /** Number of elements in the Row. 
*/ def size: Int = length @@ -155,8 +155,8 @@ trait Row extends Serializable { def schema: StructType = null /** - * Returns the value at position i. If the value is null, null is returned. The following - * is a mapping between Spark SQL types and return types: + * Returns the value at position i. If the value is null, null is returned. The following is a + * mapping between Spark SQL types and return types: * * {{{ * BooleanType -> java.lang.Boolean @@ -184,8 +184,8 @@ trait Row extends Serializable { def apply(i: Int): Any = get(i) /** - * Returns the value at position i. If the value is null, null is returned. The following - * is a mapping between Spark SQL types and return types: + * Returns the value at position i. If the value is null, null is returned. The following is a + * mapping between Spark SQL types and return types: * * {{{ * BooleanType -> java.lang.Boolean @@ -218,106 +218,127 @@ trait Row extends Serializable { /** * Returns the value at position i as a primitive boolean. * - * @throws ClassCastException when data type does not match. - * @throws org.apache.spark.SparkRuntimeException when value is null. + * @throws ClassCastException + * when data type does not match. + * @throws org.apache.spark.SparkRuntimeException + * when value is null. */ def getBoolean(i: Int): Boolean = getAnyValAs[Boolean](i) /** * Returns the value at position i as a primitive byte. * - * @throws ClassCastException when data type does not match. - * @throws org.apache.spark.SparkRuntimeException when value is null. + * @throws ClassCastException + * when data type does not match. + * @throws org.apache.spark.SparkRuntimeException + * when value is null. */ def getByte(i: Int): Byte = getAnyValAs[Byte](i) /** * Returns the value at position i as a primitive short. * - * @throws ClassCastException when data type does not match. - * @throws org.apache.spark.SparkRuntimeException when value is null. + * @throws ClassCastException + * when data type does not match. + * @throws org.apache.spark.SparkRuntimeException + * when value is null. */ def getShort(i: Int): Short = getAnyValAs[Short](i) /** * Returns the value at position i as a primitive int. * - * @throws ClassCastException when data type does not match. - * @throws org.apache.spark.SparkRuntimeException when value is null. + * @throws ClassCastException + * when data type does not match. + * @throws org.apache.spark.SparkRuntimeException + * when value is null. */ def getInt(i: Int): Int = getAnyValAs[Int](i) /** * Returns the value at position i as a primitive long. * - * @throws ClassCastException when data type does not match. - * @throws org.apache.spark.SparkRuntimeException when value is null. + * @throws ClassCastException + * when data type does not match. + * @throws org.apache.spark.SparkRuntimeException + * when value is null. */ def getLong(i: Int): Long = getAnyValAs[Long](i) /** - * Returns the value at position i as a primitive float. - * Throws an exception if the type mismatches or if the value is null. + * Returns the value at position i as a primitive float. Throws an exception if the type + * mismatches or if the value is null. * - * @throws ClassCastException when data type does not match. - * @throws org.apache.spark.SparkRuntimeException when value is null. + * @throws ClassCastException + * when data type does not match. + * @throws org.apache.spark.SparkRuntimeException + * when value is null. */ def getFloat(i: Int): Float = getAnyValAs[Float](i) /** * Returns the value at position i as a primitive double. 
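A small sketch of the generic and primitive Row accessors documented in this file; the Row instance is built inline purely for illustration:

{{{
  import org.apache.spark.sql.Row

  val row = Row(1, "one", null)

  val first: Int = row.getInt(0)         // native primitive access
  val second: String = row.getString(1)  // typed generic access
  val third: Any = row.get(2)            // generic access by ordinal, may be null

  // Check isNullAt before using a primitive getter on a possibly-null value.
  val safeThird: Int = if (row.isNullAt(2)) -1 else row.getInt(2)
}}}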
* - * @throws ClassCastException when data type does not match. - * @throws org.apache.spark.SparkRuntimeException when value is null. + * @throws ClassCastException + * when data type does not match. + * @throws org.apache.spark.SparkRuntimeException + * when value is null. */ def getDouble(i: Int): Double = getAnyValAs[Double](i) /** * Returns the value at position i as a String object. * - * @throws ClassCastException when data type does not match. + * @throws ClassCastException + * when data type does not match. */ def getString(i: Int): String = getAs[String](i) /** * Returns the value at position i of decimal type as java.math.BigDecimal. * - * @throws ClassCastException when data type does not match. + * @throws ClassCastException + * when data type does not match. */ def getDecimal(i: Int): java.math.BigDecimal = getAs[java.math.BigDecimal](i) /** * Returns the value at position i of date type as java.sql.Date. * - * @throws ClassCastException when data type does not match. + * @throws ClassCastException + * when data type does not match. */ def getDate(i: Int): java.sql.Date = getAs[java.sql.Date](i) /** * Returns the value at position i of date type as java.time.LocalDate. * - * @throws ClassCastException when data type does not match. + * @throws ClassCastException + * when data type does not match. */ def getLocalDate(i: Int): java.time.LocalDate = getAs[java.time.LocalDate](i) /** * Returns the value at position i of date type as java.sql.Timestamp. * - * @throws ClassCastException when data type does not match. + * @throws ClassCastException + * when data type does not match. */ def getTimestamp(i: Int): java.sql.Timestamp = getAs[java.sql.Timestamp](i) /** * Returns the value at position i of date type as java.time.Instant. * - * @throws ClassCastException when data type does not match. + * @throws ClassCastException + * when data type does not match. */ def getInstant(i: Int): java.time.Instant = getAs[java.time.Instant](i) /** * Returns the value at position i of array type as a Scala Seq. * - * @throws ClassCastException when data type does not match. + * @throws ClassCastException + * when data type does not match. */ def getSeq[T](i: Int): Seq[T] = { getAs[scala.collection.Seq[T]](i) match { @@ -334,7 +355,8 @@ trait Row extends Serializable { /** * Returns the value at position i of array type as `java.util.List`. * - * @throws ClassCastException when data type does not match. + * @throws ClassCastException + * when data type does not match. */ def getList[T](i: Int): java.util.List[T] = getSeq[T](i).asJava @@ -342,14 +364,16 @@ trait Row extends Serializable { /** * Returns the value at position i of map type as a Scala Map. * - * @throws ClassCastException when data type does not match. + * @throws ClassCastException + * when data type does not match. */ def getMap[K, V](i: Int): scala.collection.Map[K, V] = getAs[Map[K, V]](i) /** * Returns the value at position i of array type as a `java.util.Map`. * - * @throws ClassCastException when data type does not match. + * @throws ClassCastException + * when data type does not match. */ def getJavaMap[K, V](i: Int): java.util.Map[K, V] = getMap[K, V](i).asJava @@ -357,48 +381,56 @@ trait Row extends Serializable { /** * Returns the value at position i of struct type as a [[Row]] object. * - * @throws ClassCastException when data type does not match. + * @throws ClassCastException + * when data type does not match. */ def getStruct(i: Int): Row = getAs[Row](i) /** - * Returns the value at position i. 
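A sketch of the typed accessors documented in this hunk (getAs, getSeq, getMap, getValuesMap); the DataFrame `df` and its columns `name`, `scores`, `props` are hypothetical:

{{{
val row: Row = df.head()
val name   = row.getAs[String]("name")                           // by field name; requires a schema
val scores = row.getSeq[Int](row.fieldIndex("scores"))           // ARRAY<INT> column as a Scala Seq
val props  = row.getMap[String, String](row.fieldIndex("props"))
val asMap  = row.getValuesMap[Any](Seq("name", "scores"))        // Map(name -> ..., scores -> ...)
}}}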
- * For primitive types if value is null it returns 'zero value' specific for primitive - * i.e. 0 for Int - use isNullAt to ensure that value is not null + * Returns the value at position i. For primitive types if value is null it returns 'zero value' + * specific for primitive i.e. 0 for Int - use isNullAt to ensure that value is not null * - * @throws ClassCastException when data type does not match. + * @throws ClassCastException + * when data type does not match. */ def getAs[T](i: Int): T = get(i).asInstanceOf[T] /** - * Returns the value of a given fieldName. - * For primitive types if value is null it returns 'zero value' specific for primitive - * i.e. 0 for Int - use isNullAt to ensure that value is not null + * Returns the value of a given fieldName. For primitive types if value is null it returns 'zero + * value' specific for primitive i.e. 0 for Int - use isNullAt to ensure that value is not null * - * @throws UnsupportedOperationException when schema is not defined. - * @throws IllegalArgumentException when fieldName do not exist. - * @throws ClassCastException when data type does not match. + * @throws UnsupportedOperationException + * when schema is not defined. + * @throws IllegalArgumentException + * when fieldName do not exist. + * @throws ClassCastException + * when data type does not match. */ def getAs[T](fieldName: String): T = getAs[T](fieldIndex(fieldName)) /** * Returns the index of a given field name. * - * @throws UnsupportedOperationException when schema is not defined. - * @throws IllegalArgumentException when a field `name` does not exist. + * @throws UnsupportedOperationException + * when schema is not defined. + * @throws IllegalArgumentException + * when a field `name` does not exist. */ def fieldIndex(name: String): Int = { throw DataTypeErrors.fieldIndexOnRowWithoutSchemaError(fieldName = name) } /** - * Returns a Map consisting of names and values for the requested fieldNames - * For primitive types if value is null it returns 'zero value' specific for primitive - * i.e. 0 for Int - use isNullAt to ensure that value is not null + * Returns a Map consisting of names and values for the requested fieldNames For primitive types + * if value is null it returns 'zero value' specific for primitive i.e. 0 for Int - use isNullAt + * to ensure that value is not null * - * @throws UnsupportedOperationException when schema is not defined. - * @throws IllegalArgumentException when fieldName do not exist. - * @throws ClassCastException when data type does not match. + * @throws UnsupportedOperationException + * when schema is not defined. + * @throws IllegalArgumentException + * when fieldName do not exist. + * @throws ClassCastException + * when data type does not match. */ def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = { fieldNames.map { name => @@ -445,24 +477,25 @@ trait Row extends Serializable { o1 match { case b1: Array[Byte] => if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { return false } case f1: Float if java.lang.Float.isNaN(f1) => - if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + if (!o2.isInstanceOf[Float] || !java.lang.Float.isNaN(o2.asInstanceOf[Float])) { return false } case d1: Double if java.lang.Double.isNaN(d1) => - if (!o2.isInstanceOf[Double] || ! 
java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + if (!o2.isInstanceOf[Double] || !java.lang.Double.isNaN(o2.asInstanceOf[Double])) { return false } case d1: java.math.BigDecimal if o2.isInstanceOf[java.math.BigDecimal] => if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) { return false } - case _ => if (o1 != o2) { - return false - } + case _ => + if (o1 != o2) { + return false + } } } i += 1 @@ -505,8 +538,8 @@ trait Row extends Serializable { def mkString(sep: String): String = mkString("", sep, "") /** - * Displays all elements of this traversable or iterator in a string using - * start, end, and separator strings. + * Displays all elements of this traversable or iterator in a string using start, end, and + * separator strings. */ def mkString(start: String, sep: String, end: String): String = { val n = length @@ -528,9 +561,12 @@ trait Row extends Serializable { /** * Returns the value at position i. * - * @throws UnsupportedOperationException when schema is not defined. - * @throws ClassCastException when data type does not match. - * @throws org.apache.spark.SparkRuntimeException when value is null. + * @throws UnsupportedOperationException + * when schema is not defined. + * @throws ClassCastException + * when data type does not match. + * @throws org.apache.spark.SparkRuntimeException + * when value is null. */ private def getAnyValAs[T <: AnyVal](i: Int): T = if (isNullAt(i)) throw DataTypeErrors.valueIsNullError(i) @@ -556,7 +592,8 @@ trait Row extends Serializable { * Note that this only supports the data types that are also supported by * [[org.apache.spark.sql.catalyst.encoders.RowEncoder]]. * - * @return the JSON representation of the row. + * @return + * the JSON representation of the row. */ private[sql] def jsonValue: JValue = { require(schema != null, "JSON serialization requires a non-null schema.") @@ -598,13 +635,12 @@ trait Row extends Serializable { case (s: Seq[_], ArrayType(elementType, _)) => iteratorToJsonArray(s.iterator, elementType) case (m: Map[String @unchecked, _], MapType(StringType, valueType, _)) => - new JObject(m.toList.sortBy(_._1).map { - case (k, v) => k -> toJson(v, valueType) + new JObject(m.toList.sortBy(_._1).map { case (k, v) => + k -> toJson(v, valueType) }) case (m: Map[_, _], MapType(keyType, valueType, _)) => - new JArray(m.iterator.map { - case (k, v) => - new JObject("key" -> toJson(k, keyType) :: "value" -> toJson(v, valueType) :: Nil) + new JArray(m.iterator.map { case (k, v) => + new JObject("key" -> toJson(k, keyType) :: "value" -> toJson(v, valueType) :: Nil) }.toList) case (row: Row, schema: StructType) => var n = 0 @@ -618,13 +654,13 @@ trait Row extends Serializable { new JObject(elements.toList) case (v: Any, udt: UserDefinedType[Any @unchecked]) => toJson(UDTUtils.toRow(v, udt), udt.sqlType) - case _ => throw new SparkIllegalArgumentException( - errorClass = "FAILED_ROW_TO_JSON", - messageParameters = Map( - "value" -> toSQLValue(value.toString), - "class" -> value.getClass.toString, - "sqlType" -> toSQLType(dataType.toString)) - ) + case _ => + throw new SparkIllegalArgumentException( + errorClass = "FAILED_ROW_TO_JSON", + messageParameters = Map( + "value" -> toSQLValue(value.toString), + "class" -> value.getClass.toString, + "sqlType" -> toSQLType(dataType.toString))) } toJson(this, schema) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/api/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala new file mode 100644 index 0000000000000..23a2774ebc3a5 --- /dev/null +++ 
b/sql/api/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.annotation.Stable + +/** + * Runtime configuration interface for Spark. To access this, use `SparkSession.conf`. + * + * Options set here are automatically propagated to the Hadoop configuration during I/O. + * + * @since 2.0.0 + */ +@Stable +abstract class RuntimeConfig { + + /** + * Sets the given Spark runtime configuration property. + * + * @since 2.0.0 + */ + def set(key: String, value: String): Unit + + /** + * Sets the given Spark runtime configuration property. + * + * @since 2.0.0 + */ + def set(key: String, value: Boolean): Unit = { + set(key, value.toString) + } + + /** + * Sets the given Spark runtime configuration property. + * + * @since 2.0.0 + */ + def set(key: String, value: Long): Unit = { + set(key, value.toString) + } + + /** + * Returns the value of Spark runtime configuration property for the given key. + * + * @throws java.util.NoSuchElementException + * if the key is not set and does not have a default value + * @since 2.0.0 + */ + @throws[NoSuchElementException]("if the key is not set") + def get(key: String): String + + /** + * Returns the value of Spark runtime configuration property for the given key. + * + * @since 2.0.0 + */ + def get(key: String, default: String): String + + /** + * Returns all properties set in this conf. + * + * @since 2.0.0 + */ + def getAll: Map[String, String] + + /** + * Returns the value of Spark runtime configuration property for the given key. + * + * @since 2.0.0 + */ + def getOption(key: String): Option[String] + + /** + * Resets the configuration property for the given key. + * + * @since 2.0.0 + */ + def unset(key: String): Unit + + /** + * Indicates whether the configuration property with the given key is modifiable in the current + * session. + * + * @return + * `true` if the configuration property is modifiable. For static SQL, Spark Core, invalid + * (not existing) and other non-modifiable configuration properties, the returned value is + * `false`. + * @since 2.4.0 + */ + def isModifiable(key: String): Boolean +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Catalog.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Catalog.scala new file mode 100644 index 0000000000000..a0f51d30dc572 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Catalog.scala @@ -0,0 +1,682 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
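Before the new Catalog interface below, a minimal sketch of the RuntimeConfig surface just added, assuming an active SparkSession `spark`; the keys are illustrative:

{{{
spark.conf.set("spark.sql.shuffle.partitions", 64L)             // Long overload delegates to set(key, value.toString)
val partitions = spark.conf.get("spark.sql.shuffle.partitions")
val maybeValue = spark.conf.getOption("spark.some.unset.key")   // hypothetical key: None instead of NoSuchElementException
if (spark.conf.isModifiable("spark.sql.shuffle.partitions")) {
  spark.conf.unset("spark.sql.shuffle.partitions")              // reverts to the default value
}
}}}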
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api + +import scala.jdk.CollectionConverters._ + +import _root_.java.util + +import org.apache.spark.annotation.Stable +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.catalog.{CatalogMetadata, Column, Database, Function, Table} +import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel + +/** + * Catalog interface for Spark. To access this, use `SparkSession.catalog`. + * + * @since 2.0.0 + */ +@Stable +abstract class Catalog { + + /** + * Returns the current database (namespace) in this session. + * + * @since 2.0.0 + */ + def currentDatabase: String + + /** + * Sets the current database (namespace) in this session. + * + * @since 2.0.0 + */ + def setCurrentDatabase(dbName: String): Unit + + /** + * Returns a list of databases (namespaces) available within the current catalog. + * + * @since 2.0.0 + */ + def listDatabases(): Dataset[Database] + + /** + * Returns a list of databases (namespaces) which name match the specify pattern and available + * within the current catalog. + * + * @since 3.5.0 + */ + def listDatabases(pattern: String): Dataset[Database] + + /** + * Returns a list of tables/views in the current database (namespace). This includes all + * temporary views. + * + * @since 2.0.0 + */ + def listTables(): Dataset[Table] + + /** + * Returns a list of tables/views in the specified database (namespace) (the name can be + * qualified with catalog). This includes all temporary views. + * + * @since 2.0.0 + */ + @throws[AnalysisException]("database does not exist") + def listTables(dbName: String): Dataset[Table] + + /** + * Returns a list of tables/views in the specified database (namespace) which name match the + * specify pattern (the name can be qualified with catalog). This includes all temporary views. + * + * @since 3.5.0 + */ + @throws[AnalysisException]("database does not exist") + def listTables(dbName: String, pattern: String): Dataset[Table] + + /** + * Returns a list of functions registered in the current database (namespace). This includes all + * temporary functions. + * + * @since 2.0.0 + */ + def listFunctions(): Dataset[Function] + + /** + * Returns a list of functions registered in the specified database (namespace) (the name can be + * qualified with catalog). This includes all built-in and temporary functions. + * + * @since 2.0.0 + */ + @throws[AnalysisException]("database does not exist") + def listFunctions(dbName: String): Dataset[Function] + + /** + * Returns a list of functions registered in the specified database (namespace) which name match + * the specify pattern (the name can be qualified with catalog). This includes all built-in and + * temporary functions. 
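A sketch of the namespace and listing calls declared so far in this Catalog interface, assuming a SparkSession `spark` with a populated `default` database; the patterns are illustrative:

{{{
val catalog = spark.catalog
println(catalog.currentDatabase)
catalog.listDatabases("dev*").show()               // pattern overload (3.5.0)
catalog.listTables("default").show()               // includes temporary views
catalog.listFunctions("default", "max*").show()    // built-in and temporary functions matching the pattern
}}}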
+ * + * @since 3.5.0 + */ + @throws[AnalysisException]("database does not exist") + def listFunctions(dbName: String, pattern: String): Dataset[Function] + + /** + * Returns a list of columns for the given table/view or temporary view. + * + * @param tableName + * is either a qualified or unqualified name that designates a table/view. It follows the same + * resolution rule with SQL: search for temp views first then table/views in the current + * database (namespace). + * @since 2.0.0 + */ + @throws[AnalysisException]("table does not exist") + def listColumns(tableName: String): Dataset[Column] + + /** + * Returns a list of columns for the given table/view in the specified database under the Hive + * Metastore. + * + * To list columns for table/view in other catalogs, please use `listColumns(tableName)` with + * qualified table/view name instead. + * + * @param dbName + * is an unqualified name that designates a database. + * @param tableName + * is an unqualified name that designates a table/view. + * @since 2.0.0 + */ + @throws[AnalysisException]("database or table does not exist") + def listColumns(dbName: String, tableName: String): Dataset[Column] + + /** + * Get the database (namespace) with the specified name (can be qualified with catalog). This + * throws an AnalysisException when the database (namespace) cannot be found. + * + * @since 2.1.0 + */ + @throws[AnalysisException]("database does not exist") + def getDatabase(dbName: String): Database + + /** + * Get the table or view with the specified name. This table can be a temporary view or a + * table/view. This throws an AnalysisException when no Table can be found. + * + * @param tableName + * is either a qualified or unqualified name that designates a table/view. It follows the same + * resolution rule with SQL: search for temp views first then table/views in the current + * database (namespace). + * @since 2.1.0 + */ + @throws[AnalysisException]("table does not exist") + def getTable(tableName: String): Table + + /** + * Get the table or view with the specified name in the specified database under the Hive + * Metastore. This throws an AnalysisException when no Table can be found. + * + * To get table/view in other catalogs, please use `getTable(tableName)` with qualified + * table/view name instead. + * + * @since 2.1.0 + */ + @throws[AnalysisException]("database or table does not exist") + def getTable(dbName: String, tableName: String): Table + + /** + * Get the function with the specified name. This function can be a temporary function or a + * function. This throws an AnalysisException when the function cannot be found. + * + * @param functionName + * is either a qualified or unqualified name that designates a function. It follows the same + * resolution rule with SQL: search for built-in/temp functions first then functions in the + * current database (namespace). + * @since 2.1.0 + */ + @throws[AnalysisException]("function does not exist") + def getFunction(functionName: String): Function + + /** + * Get the function with the specified name in the specified database under the Hive Metastore. + * This throws an AnalysisException when the function cannot be found. + * + * To get functions in other catalogs, please use `getFunction(functionName)` with qualified + * function name instead. + * + * @param dbName + * is an unqualified name that designates a database. 
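And for the lookup methods above (listColumns, getDatabase, getTable), again with hypothetical database and table names:

{{{
val table = spark.catalog.getTable("default", "people")     // hypothetical table; AnalysisException if it does not exist
spark.catalog.listColumns("default", "people").show()
val db = spark.catalog.getDatabase("default")
println(s"${db.name} -> ${db.locationUri}")
}}}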
+ * @param functionName + * is an unqualified name that designates a function in the specified database + * @since 2.1.0 + */ + @throws[AnalysisException]("database or function does not exist") + def getFunction(dbName: String, functionName: String): Function + + /** + * Check if the database (namespace) with the specified name exists (the name can be qualified + * with catalog). + * + * @since 2.1.0 + */ + def databaseExists(dbName: String): Boolean + + /** + * Check if the table or view with the specified name exists. This can either be a temporary + * view or a table/view. + * + * @param tableName + * is either a qualified or unqualified name that designates a table/view. It follows the same + * resolution rule with SQL: search for temp views first then table/views in the current + * database (namespace). + * @since 2.1.0 + */ + def tableExists(tableName: String): Boolean + + /** + * Check if the table or view with the specified name exists in the specified database under the + * Hive Metastore. + * + * To check existence of table/view in other catalogs, please use `tableExists(tableName)` with + * qualified table/view name instead. + * + * @param dbName + * is an unqualified name that designates a database. + * @param tableName + * is an unqualified name that designates a table. + * @since 2.1.0 + */ + def tableExists(dbName: String, tableName: String): Boolean + + /** + * Check if the function with the specified name exists. This can either be a temporary function + * or a function. + * + * @param functionName + * is either a qualified or unqualified name that designates a function. It follows the same + * resolution rule with SQL: search for built-in/temp functions first then functions in the + * current database (namespace). + * @since 2.1.0 + */ + def functionExists(functionName: String): Boolean + + /** + * Check if the function with the specified name exists in the specified database under the Hive + * Metastore. + * + * To check existence of functions in other catalogs, please use `functionExists(functionName)` + * with qualified function name instead. + * + * @param dbName + * is an unqualified name that designates a database. + * @param functionName + * is an unqualified name that designates a function. + * @since 2.1.0 + */ + def functionExists(dbName: String, functionName: String): Boolean + + /** + * Creates a table from the given path and returns the corresponding DataFrame. It will use the + * default data source configured by spark.sql.sources.default. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 2.0.0 + */ + @deprecated("use createTable instead.", "2.2.0") + def createExternalTable(tableName: String, path: String): Dataset[Row] = { + createTable(tableName, path) + } + + /** + * Creates a table from the given path and returns the corresponding DataFrame. It will use the + * default data source configured by spark.sql.sources.default. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 2.2.0 + */ + def createTable(tableName: String, path: String): Dataset[Row] + + /** + * Creates a table from the given path based on a data source and returns the corresponding + * DataFrame. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. 
If no database + * identifier is provided, it refers to a table in the current database. + * @since 2.0.0 + */ + @deprecated("use createTable instead.", "2.2.0") + def createExternalTable(tableName: String, path: String, source: String): Dataset[Row] = { + createTable(tableName, path, source) + } + + /** + * Creates a table from the given path based on a data source and returns the corresponding + * DataFrame. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 2.2.0 + */ + def createTable(tableName: String, path: String, source: String): Dataset[Row] + + /** + * Creates a table from the given path based on a data source and a set of options. Then, + * returns the corresponding DataFrame. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 2.0.0 + */ + @deprecated("use createTable instead.", "2.2.0") + def createExternalTable( + tableName: String, + source: String, + options: util.Map[String, String]): Dataset[Row] = { + createTable(tableName, source, options) + } + + /** + * Creates a table based on the dataset in a data source and a set of options. Then, returns the + * corresponding DataFrame. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 2.2.0 + */ + def createTable( + tableName: String, + source: String, + options: util.Map[String, String]): Dataset[Row] = { + createTable(tableName, source, options.asScala.toMap) + } + + /** + * (Scala-specific) Creates a table from the given path based on a data source and a set of + * options. Then, returns the corresponding DataFrame. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 2.0.0 + */ + @deprecated("use createTable instead.", "2.2.0") + def createExternalTable( + tableName: String, + source: String, + options: Map[String, String]): Dataset[Row] = { + createTable(tableName, source, options) + } + + /** + * (Scala-specific) Creates a table based on the dataset in a data source and a set of options. + * Then, returns the corresponding DataFrame. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 2.2.0 + */ + def createTable(tableName: String, source: String, options: Map[String, String]): Dataset[Row] + + /** + * Create a table from the given path based on a data source, a schema and a set of options. + * Then, returns the corresponding DataFrame. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 2.0.0 + */ + @deprecated("use createTable instead.", "2.2.0") + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: util.Map[String, String]): Dataset[Row] = { + createTable(tableName, source, schema, options) + } + + /** + * Creates a table based on the dataset in a data source and a set of options. Then, returns the + * corresponding DataFrame. 
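A sketch of the two main createTable shapes documented above, path-based and source-plus-options; the table names and paths are hypothetical:

{{{
// Path-based: uses the default data source configured by spark.sql.sources.default.
val events = spark.catalog.createTable("events", "/data/events")          // hypothetical path

// Source + options variant.
val people = spark.catalog.createTable(
  "people_csv",
  source = "csv",
  options = Map("path" -> "/data/people.csv", "header" -> "true"))        // hypothetical path
}}}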
+ * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 3.1.0 + */ + def createTable( + tableName: String, + source: String, + description: String, + options: util.Map[String, String]): Dataset[Row] = { + createTable( + tableName, + source = source, + description = description, + options = options.asScala.toMap) + } + + /** + * (Scala-specific) Creates a table based on the dataset in a data source and a set of options. + * Then, returns the corresponding DataFrame. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 3.1.0 + */ + def createTable( + tableName: String, + source: String, + description: String, + options: Map[String, String]): Dataset[Row] + + /** + * Create a table based on the dataset in a data source, a schema and a set of options. Then, + * returns the corresponding DataFrame. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 2.2.0 + */ + def createTable( + tableName: String, + source: String, + schema: StructType, + options: util.Map[String, String]): Dataset[Row] = { + createTable(tableName, source, schema, options.asScala.toMap) + } + + /** + * (Scala-specific) Create a table from the given path based on a data source, a schema and a + * set of options. Then, returns the corresponding DataFrame. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 2.0.0 + */ + @deprecated("use createTable instead.", "2.2.0") + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: Map[String, String]): Dataset[Row] = { + createTable(tableName, source, schema, options) + } + + /** + * (Scala-specific) Create a table based on the dataset in a data source, a schema and a set of + * options. Then, returns the corresponding DataFrame. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 2.2.0 + */ + def createTable( + tableName: String, + source: String, + schema: StructType, + options: Map[String, String]): Dataset[Row] + + /** + * Create a table based on the dataset in a data source, a schema and a set of options. Then, + * returns the corresponding DataFrame. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 3.1.0 + */ + def createTable( + tableName: String, + source: String, + schema: StructType, + description: String, + options: util.Map[String, String]): Dataset[Row] = { + createTable( + tableName, + source = source, + schema = schema, + description = description, + options = options.asScala.toMap) + } + + /** + * (Scala-specific) Create a table based on the dataset in a data source, a schema and a set of + * options. Then, returns the corresponding DataFrame. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. 
If no database + * identifier is provided, it refers to a table in the current database. + * @since 3.1.0 + */ + def createTable( + tableName: String, + source: String, + schema: StructType, + description: String, + options: Map[String, String]): Dataset[Row] + + /** + * Drops the local temporary view with the given view name in the catalog. If the view has been + * cached before, then it will also be uncached. + * + * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that + * created it, i.e. it will be automatically dropped when the session terminates. It's not tied + * to any databases, i.e. we can't use `db1.view1` to reference a local temporary view. + * + * Note that, the return type of this method was Unit in Spark 2.0, but changed to Boolean in + * Spark 2.1. + * + * @param viewName + * the name of the temporary view to be dropped. + * @return + * true if the view is dropped successfully, false otherwise. + * @since 2.0.0 + */ + def dropTempView(viewName: String): Boolean + + /** + * Drops the global temporary view with the given view name in the catalog. If the view has been + * cached before, then it will also be uncached. + * + * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark + * application, i.e. it will be automatically dropped when the application terminates. It's tied + * to a system preserved database `global_temp`, and we must use the qualified name to refer a + * global temp view, e.g. `SELECT * FROM global_temp.view1`. + * + * @param viewName + * the unqualified name of the temporary view to be dropped. + * @return + * true if the view is dropped successfully, false otherwise. + * @since 2.1.0 + */ + def dropGlobalTempView(viewName: String): Boolean + + /** + * Recovers all the partitions in the directory of a table and update the catalog. Only works + * with a partitioned table, and not a view. + * + * @param tableName + * is either a qualified or unqualified name that designates a table. If no database + * identifier is provided, it refers to a table in the current database. + * @since 2.1.1 + */ + def recoverPartitions(tableName: String): Unit + + /** + * Returns true if the table is currently cached in-memory. + * + * @param tableName + * is either a qualified or unqualified name that designates a table/view. If no database + * identifier is provided, it refers to a temporary view or a table/view in the current + * database. + * @since 2.0.0 + */ + def isCached(tableName: String): Boolean + + /** + * Caches the specified table in-memory. + * + * @param tableName + * is either a qualified or unqualified name that designates a table/view. If no database + * identifier is provided, it refers to a temporary view or a table/view in the current + * database. + * @since 2.0.0 + */ + def cacheTable(tableName: String): Unit + + /** + * Caches the specified table with the given storage level. + * + * @param tableName + * is either a qualified or unqualified name that designates a table/view. If no database + * identifier is provided, it refers to a temporary view or a table/view in the current + * database. + * @param storageLevel + * storage level to cache table. + * @since 2.3.0 + */ + def cacheTable(tableName: String, storageLevel: StorageLevel): Unit + + /** + * Removes the specified table from the in-memory cache. + * + * @param tableName + * is either a qualified or unqualified name that designates a table/view. 
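A sketch tying together the temp-view and caching calls documented above, assuming a DataFrame `df`:

{{{
import org.apache.spark.storage.StorageLevel

df.createOrReplaceTempView("people_tmp")
spark.catalog.cacheTable("people_tmp", StorageLevel.MEMORY_AND_DISK)
assert(spark.catalog.isCached("people_tmp"))
spark.catalog.uncacheTable("people_tmp")
spark.catalog.dropTempView("people_tmp")          // returns false if the view no longer exists
}}}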
If no database + * identifier is provided, it refers to a temporary view or a table/view in the current + * database. + * @since 2.0.0 + */ + def uncacheTable(tableName: String): Unit + + /** + * Removes all cached tables from the in-memory cache. + * + * @since 2.0.0 + */ + def clearCache(): Unit + + /** + * Invalidates and refreshes all the cached data and metadata of the given table. For + * performance reasons, Spark SQL or the external data source library it uses might cache + * certain metadata about a table, such as the location of blocks. When those change outside of + * Spark SQL, users should call this function to invalidate the cache. + * + * If this table is cached as an InMemoryRelation, drop the original cached version and make the + * new version cached lazily. + * + * @param tableName + * is either a qualified or unqualified name that designates a table/view. If no database + * identifier is provided, it refers to a temporary view or a table/view in the current + * database. + * @since 2.0.0 + */ + def refreshTable(tableName: String): Unit + + /** + * Invalidates and refreshes all the cached data (and the associated metadata) for any `Dataset` + * that contains the given data source path. Path matching is by checking for sub-directories, + * i.e. "/" would invalidate everything that is cached and "/test/parent" would invalidate + * everything that is a subdirectory of "/test/parent". + * + * @since 2.0.0 + */ + def refreshByPath(path: String): Unit + + /** + * Returns the current catalog in this session. + * + * @since 3.4.0 + */ + def currentCatalog(): String + + /** + * Sets the current catalog in this session. + * + * @since 3.4.0 + */ + def setCurrentCatalog(catalogName: String): Unit + + /** + * Returns a list of catalogs available in this session. + * + * @since 3.4.0 + */ + def listCatalogs(): Dataset[CatalogMetadata] + + /** + * Returns a list of catalogs which name match the specify pattern and available in this + * session. + * + * @since 3.5.0 + */ + def listCatalogs(pattern: String): Dataset[CatalogMetadata] +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala index 7400f90992d8f..ef6cc64c058a4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala @@ -30,32 +30,32 @@ import org.apache.spark.util.ArrayImplicits._ * @since 1.3.1 */ @Stable -abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { +abstract class DataFrameNaFunctions { /** * Returns a new `DataFrame` that drops rows containing any null or NaN values. * * @since 1.3.1 */ - def drop(): DS[Row] = drop("any") + def drop(): Dataset[Row] = drop("any") /** * Returns a new `DataFrame` that drops rows containing null or NaN values. * - * If `how` is "any", then drop rows containing any null or NaN values. - * If `how` is "all", then drop rows only if every column is null or NaN for that row. + * If `how` is "any", then drop rows containing any null or NaN values. If `how` is "all", then + * drop rows only if every column is null or NaN for that row. * * @since 1.3.1 */ - def drop(how: String): DS[Row] = drop(toMinNonNulls(how)) + def drop(how: String): Dataset[Row] = drop(toMinNonNulls(how)) /** - * Returns a new `DataFrame` that drops rows containing any null or NaN values - * in the specified columns. 
+ * Returns a new `DataFrame` that drops rows containing any null or NaN values in the specified + * columns. * * @since 1.3.1 */ - def drop(cols: Array[String]): DS[Row] = drop(cols.toImmutableArraySeq) + def drop(cols: Array[String]): Dataset[Row] = drop(cols.toImmutableArraySeq) /** * (Scala-specific) Returns a new `DataFrame` that drops rows containing any null or NaN values @@ -63,54 +63,54 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def drop(cols: Seq[String]): DS[Row] = drop(cols.size, cols) + def drop(cols: Seq[String]): Dataset[Row] = drop(cols.size, cols) /** - * Returns a new `DataFrame` that drops rows containing null or NaN values - * in the specified columns. + * Returns a new `DataFrame` that drops rows containing null or NaN values in the specified + * columns. * * If `how` is "any", then drop rows containing any null or NaN values in the specified columns. * If `how` is "all", then drop rows only if every specified column is null or NaN for that row. * * @since 1.3.1 */ - def drop(how: String, cols: Array[String]): DS[Row] = drop(how, cols.toImmutableArraySeq) + def drop(how: String, cols: Array[String]): Dataset[Row] = drop(how, cols.toImmutableArraySeq) /** - * (Scala-specific) Returns a new `DataFrame` that drops rows containing null or NaN values - * in the specified columns. + * (Scala-specific) Returns a new `DataFrame` that drops rows containing null or NaN values in + * the specified columns. * * If `how` is "any", then drop rows containing any null or NaN values in the specified columns. * If `how` is "all", then drop rows only if every specified column is null or NaN for that row. * * @since 1.3.1 */ - def drop(how: String, cols: Seq[String]): DS[Row] = drop(toMinNonNulls(how), cols) + def drop(how: String, cols: Seq[String]): Dataset[Row] = drop(toMinNonNulls(how), cols) /** - * Returns a new `DataFrame` that drops rows containing - * less than `minNonNulls` non-null and non-NaN values. + * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and + * non-NaN values. * * @since 1.3.1 */ - def drop(minNonNulls: Int): DS[Row] = drop(Option(minNonNulls)) + def drop(minNonNulls: Int): Dataset[Row] = drop(Option(minNonNulls)) /** - * Returns a new `DataFrame` that drops rows containing - * less than `minNonNulls` non-null and non-NaN values in the specified columns. + * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and + * non-NaN values in the specified columns. * * @since 1.3.1 */ - def drop(minNonNulls: Int, cols: Array[String]): DS[Row] = + def drop(minNonNulls: Int, cols: Array[String]): Dataset[Row] = drop(minNonNulls, cols.toImmutableArraySeq) /** - * (Scala-specific) Returns a new `DataFrame` that drops rows containing less than - * `minNonNulls` non-null and non-NaN values in the specified columns. + * (Scala-specific) Returns a new `DataFrame` that drops rows containing less than `minNonNulls` + * non-null and non-NaN values in the specified columns. 
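The drop overloads above compose as follows; `df` and the column names are hypothetical:

{{{
val cleaned  = df.na.drop()                                  // drop rows containing any null or NaN
val allNull  = df.na.drop("all", Seq("age", "height"))       // drop only if every listed column is null/NaN
val atLeast2 = df.na.drop(2, Seq("name", "age", "height"))   // keep rows with at least 2 non-null values among these
}}}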
* * @since 1.3.1 */ - def drop(minNonNulls: Int, cols: Seq[String]): DS[Row] = drop(Option(minNonNulls), cols) + def drop(minNonNulls: Int, cols: Seq[String]): Dataset[Row] = drop(Option(minNonNulls), cols) private def toMinNonNulls(how: String): Option[Int] = { how.toLowerCase(util.Locale.ROOT) match { @@ -120,45 +120,46 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { } } - protected def drop(minNonNulls: Option[Int]): DS[Row] + protected def drop(minNonNulls: Option[Int]): Dataset[Row] - protected def drop(minNonNulls: Option[Int], cols: Seq[String]): DS[Row] + protected def drop(minNonNulls: Option[Int], cols: Seq[String]): Dataset[Row] /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * * @since 2.2.0 */ - def fill(value: Long): DS[Row] + def fill(value: Long): Dataset[Row] /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * @since 1.3.1 */ - def fill(value: Double): DS[Row] + def fill(value: Double): Dataset[Row] /** * Returns a new `DataFrame` that replaces null values in string columns with `value`. * * @since 1.3.1 */ - def fill(value: String): DS[Row] + def fill(value: String): Dataset[Row] /** - * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. - * If a specified column is not a numeric column, it is ignored. + * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a + * specified column is not a numeric column, it is ignored. * * @since 2.2.0 */ - def fill(value: Long, cols: Array[String]): DS[Row] = fill(value, cols.toImmutableArraySeq) + def fill(value: Long, cols: Array[String]): Dataset[Row] = fill(value, cols.toImmutableArraySeq) /** - * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. - * If a specified column is not a numeric column, it is ignored. + * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a + * specified column is not a numeric column, it is ignored. * * @since 1.3.1 */ - def fill(value: Double, cols: Array[String]): DS[Row] = fill(value, cols.toImmutableArraySeq) + def fill(value: Double, cols: Array[String]): Dataset[Row] = + fill(value, cols.toImmutableArraySeq) /** * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified @@ -166,7 +167,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 2.2.0 */ - def fill(value: Long, cols: Seq[String]): DS[Row] + def fill(value: Long, cols: Seq[String]): Dataset[Row] /** * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified @@ -174,58 +175,58 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def fill(value: Double, cols: Seq[String]): DS[Row] - + def fill(value: Double, cols: Seq[String]): Dataset[Row] /** - * Returns a new `DataFrame` that replaces null values in specified string columns. - * If a specified column is not a string column, it is ignored. + * Returns a new `DataFrame` that replaces null values in specified string columns. If a + * specified column is not a string column, it is ignored. 
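Likewise for the fill overloads; columns whose type does not match the value are silently ignored, as noted above:

{{{
val filledAges = df.na.fill(0L, Seq("age"))           // numeric columns only
val filledNums = df.na.fill(0.5)                      // every numeric column
val filledStrs = df.na.fill("unknown", Seq("name"))   // string columns only
}}}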
* * @since 1.3.1 */ - def fill(value: String, cols: Array[String]): DS[Row] = fill(value, cols.toImmutableArraySeq) + def fill(value: String, cols: Array[String]): Dataset[Row] = + fill(value, cols.toImmutableArraySeq) /** - * (Scala-specific) Returns a new `DataFrame` that replaces null values in - * specified string columns. If a specified column is not a string column, it is ignored. + * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified string + * columns. If a specified column is not a string column, it is ignored. * * @since 1.3.1 */ - def fill(value: String, cols: Seq[String]): DS[Row] + def fill(value: String, cols: Seq[String]): Dataset[Row] /** * Returns a new `DataFrame` that replaces null values in boolean columns with `value`. * * @since 2.3.0 */ - def fill(value: Boolean): DS[Row] + def fill(value: Boolean): Dataset[Row] /** - * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified - * boolean columns. If a specified column is not a boolean column, it is ignored. + * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified boolean + * columns. If a specified column is not a boolean column, it is ignored. * * @since 2.3.0 */ - def fill(value: Boolean, cols: Seq[String]): DS[Row] + def fill(value: Boolean, cols: Seq[String]): Dataset[Row] /** - * Returns a new `DataFrame` that replaces null values in specified boolean columns. - * If a specified column is not a boolean column, it is ignored. + * Returns a new `DataFrame` that replaces null values in specified boolean columns. If a + * specified column is not a boolean column, it is ignored. * * @since 2.3.0 */ - def fill(value: Boolean, cols: Array[String]): DS[Row] = fill(value, cols.toImmutableArraySeq) + def fill(value: Boolean, cols: Array[String]): Dataset[Row] = + fill(value, cols.toImmutableArraySeq) /** * Returns a new `DataFrame` that replaces null values. * - * The key of the map is the column name, and the value of the map is the replacement value. - * The value must be of the following type: - * `Integer`, `Long`, `Float`, `Double`, `String`, `Boolean`. - * Replacement values are cast to the column data type. + * The key of the map is the column name, and the value of the map is the replacement value. The + * value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`, + * `Boolean`. Replacement values are cast to the column data type. * - * For example, the following replaces null values in column "A" with string "unknown", and - * null values in column "B" with numeric value 1.0. + * For example, the following replaces null values in column "A" with string "unknown", and null + * values in column "B" with numeric value 1.0. * {{{ * import com.google.common.collect.ImmutableMap; * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0)); @@ -233,17 +234,17 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def fill(valueMap: util.Map[String, Any]): DS[Row] = fillMap(valueMap.asScala.toSeq) + def fill(valueMap: util.Map[String, Any]): Dataset[Row] = fillMap(valueMap.asScala.toSeq) /** * (Scala-specific) Returns a new `DataFrame` that replaces null values. * - * The key of the map is the column name, and the value of the map is the replacement value. - * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`, `Boolean`. + * The key of the map is the column name, and the value of the map is the replacement value. 
The + * value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`, `Boolean`. * Replacement values are cast to the column data type. * - * For example, the following replaces null values in column "A" with string "unknown", and - * null values in column "B" with numeric value 1.0. + * For example, the following replaces null values in column "A" with string "unknown", and null + * values in column "B" with numeric value 1.0. * {{{ * df.na.fill(Map( * "A" -> "unknown", @@ -253,9 +254,9 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def fill(valueMap: Map[String, Any]): DS[Row] = fillMap(valueMap.toSeq) + def fill(valueMap: Map[String, Any]): Dataset[Row] = fillMap(valueMap.toSeq) - protected def fillMap(values: Seq[(String, Any)]): DS[Row] + protected def fillMap(values: Seq[(String, Any)]): Dataset[Row] /** * Replaces values matching keys in `replacement` map with the corresponding values. @@ -273,15 +274,16 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * - * @param col name of the column to apply the value replacement. If `col` is "*", - * replacement is applied on all string, numeric or boolean columns. - * @param replacement value replacement map. Key and value of `replacement` map must have - * the same type, and can only be doubles, strings or booleans. - * The map value can have nulls. + * @param col + * name of the column to apply the value replacement. If `col` is "*", replacement is applied + * on all string, numeric or boolean columns. + * @param replacement + * value replacement map. Key and value of `replacement` map must have the same type, and can + * only be doubles, strings or booleans. The map value can have nulls. * * @since 1.3.1 */ - def replace[T](col: String, replacement: util.Map[T, T]): DS[Row] = { + def replace[T](col: String, replacement: util.Map[T, T]): Dataset[Row] = { replace[T](col, replacement.asScala.toMap) } @@ -298,15 +300,16 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * - * @param cols list of columns to apply the value replacement. If `col` is "*", - * replacement is applied on all string, numeric or boolean columns. - * @param replacement value replacement map. Key and value of `replacement` map must have - * the same type, and can only be doubles, strings or booleans. - * The map value can have nulls. + * @param cols + * list of columns to apply the value replacement. If `col` is "*", replacement is applied on + * all string, numeric or boolean columns. + * @param replacement + * value replacement map. Key and value of `replacement` map must have the same type, and can + * only be doubles, strings or booleans. The map value can have nulls. * * @since 1.3.1 */ - def replace[T](cols: Array[String], replacement: util.Map[T, T]): DS[Row] = { + def replace[T](cols: Array[String], replacement: util.Map[T, T]): Dataset[Row] = { replace(cols.toImmutableArraySeq, replacement.asScala.toMap) } @@ -324,15 +327,16 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * df.na.replace("*", Map("UNKNOWN" -> "unnamed")); * }}} * - * @param col name of the column to apply the value replacement. If `col` is "*", - * replacement is applied on all string, numeric or boolean columns. - * @param replacement value replacement map. 
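And the map-driven variants, mirroring the examples embedded in the Scaladoc above:

{{{
val filled     = df.na.fill(Map("name" -> "unknown", "age" -> 0))   // per-column replacement values
val renamed    = df.na.replace("name", Map("UNKNOWN" -> "unnamed"))
val everywhere = df.na.replace("*", Map(-1.0 -> Double.NaN))        // "*" targets all string/numeric/boolean columns
}}}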
Key and value of `replacement` map must have - * the same type, and can only be doubles, strings or booleans. - * The map value can have nulls. + * @param col + * name of the column to apply the value replacement. If `col` is "*", replacement is applied + * on all string, numeric or boolean columns. + * @param replacement + * value replacement map. Key and value of `replacement` map must have the same type, and can + * only be doubles, strings or booleans. The map value can have nulls. * * @since 1.3.1 */ - def replace[T](col: String, replacement: Map[T, T]): DS[Row] + def replace[T](col: String, replacement: Map[T, T]): Dataset[Row] /** * (Scala-specific) Replaces values matching keys in `replacement` map. @@ -345,13 +349,14 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed")); * }}} * - * @param cols list of columns to apply the value replacement. If `col` is "*", - * replacement is applied on all string, numeric or boolean columns. - * @param replacement value replacement map. Key and value of `replacement` map must have - * the same type, and can only be doubles, strings or booleans. - * The map value can have nulls. + * @param cols + * list of columns to apply the value replacement. If `col` is "*", replacement is applied on + * all string, numeric or boolean columns. + * @param replacement + * value replacement map. Key and value of `replacement` map must have the same type, and can + * only be doubles, strings or booleans. The map value can have nulls. * * @since 1.3.1 */ - def replace[T](cols: Seq[String], replacement: Map[T, T]): DS[Row] + def replace[T](cols: Seq[String], replacement: Map[T, T]): Dataset[Row] } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala new file mode 100644 index 0000000000000..c101c52fd0662 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala @@ -0,0 +1,571 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.api + +import scala.jdk.CollectionConverters._ + +import _root_.java.util + +import org.apache.spark.annotation.Stable +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, SparkCharVarcharUtils} +import org.apache.spark.sql.errors.DataTypeErrors +import org.apache.spark.sql.types.StructType + +/** + * Interface used to load a [[Dataset]] from external storage systems (e.g. file systems, + * key-value stores, etc). Use `SparkSession.read` to access this. 
+ * + * @since 1.4.0 + */ +@Stable +abstract class DataFrameReader { + type DS[U] <: Dataset[U] + + /** + * Specifies the input data source format. + * + * @since 1.4.0 + */ + def format(source: String): this.type = { + this.source = source + this + } + + /** + * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema + * automatically from data. By specifying the schema here, the underlying data source can skip + * the schema inference step, and thus speed up data loading. + * + * @since 1.4.0 + */ + def schema(schema: StructType): this.type = { + if (schema != null) { + val replaced = SparkCharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] + this.userSpecifiedSchema = Option(replaced) + validateSingleVariantColumn() + } + this + } + + /** + * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) + * can infer the input schema automatically from data. By specifying the schema here, the + * underlying data source can skip the schema inference step, and thus speed up data loading. + * + * {{{ + * spark.read.schema("a INT, b STRING, c DOUBLE").csv("test.csv") + * }}} + * + * @since 2.3.0 + */ + def schema(schemaString: String): this.type = schema(StructType.fromDDL(schemaString)) + + /** + * Adds an input option for the underlying data source. + * + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. + * + * @since 1.4.0 + */ + def option(key: String, value: String): this.type = { + this.extraOptions = this.extraOptions + (key -> value) + validateSingleVariantColumn() + this + } + + /** + * Adds an input option for the underlying data source. + * + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. + * + * @since 2.0.0 + */ + def option(key: String, value: Boolean): this.type = option(key, value.toString) + + /** + * Adds an input option for the underlying data source. + * + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. + * + * @since 2.0.0 + */ + def option(key: String, value: Long): this.type = option(key, value.toString) + + /** + * Adds an input option for the underlying data source. + * + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. + * + * @since 2.0.0 + */ + def option(key: String, value: Double): this.type = option(key, value.toString) + + /** + * (Scala-specific) Adds input options for the underlying data source. + * + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. + * + * @since 1.4.0 + */ + def options(options: scala.collection.Map[String, String]): this.type = { + this.extraOptions ++= options + validateSingleVariantColumn() + this + } + + /** + * Adds input options for the underlying data source. + * + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. 
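A minimal read pipeline exercising format, the DDL-string schema, and the option overloads above; the path is hypothetical:

{{{
val people = spark.read
  .format("csv")
  .schema("name STRING, age INT, height DOUBLE")   // DDL string, skips schema inference
  .option("header", true)
  .option("mode", "DROPMALFORMED")
  .load("/data/people.csv")                        // hypothetical path
}}}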
+ * + * @since 1.4.0 + */ + def options(opts: util.Map[String, String]): this.type = options(opts.asScala) + + /** + * Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external + * key-value stores). + * + * @since 1.4.0 + */ + def load(): Dataset[Row] + + /** + * Loads input in as a `DataFrame`, for data sources that require a path (e.g. data backed by a + * local or distributed file system). + * + * @since 1.4.0 + */ + def load(path: String): Dataset[Row] + + /** + * Loads input in as a `DataFrame`, for data sources that support multiple paths. Only works if + * the source is a HadoopFsRelationProvider. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def load(paths: String*): Dataset[Row] + + /** + * Construct a `DataFrame` representing the database table accessible via JDBC URL url named + * table and connection properties. + * + * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC + * in + * Data Source Option in the version you use. + * + * @since 1.4.0 + */ + def jdbc(url: String, table: String, properties: util.Properties): Dataset[Row] = { + assertNoSpecifiedSchema("jdbc") + // properties should override settings in extraOptions. + this.extraOptions ++= properties.asScala + // explicit url and dbtable should override all + this.extraOptions ++= Seq("url" -> url, "dbtable" -> table) + format("jdbc").load() + } + + // scalastyle:off line.size.limit + /** + * Construct a `DataFrame` representing the database table accessible via JDBC URL url named + * table. Partitions of the table will be retrieved in parallel based on the parameters passed + * to this function. + * + * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash + * your external database systems. + * + * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC + * in + * Data Source Option in the version you use. + * + * @param table + * Name of the table in the external database. + * @param columnName + * Alias of `partitionColumn` option. Refer to `partitionColumn` in + * Data Source Option in the version you use. + * @param connectionProperties + * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least + * a "user" and "password" property should be included. "fetchsize" can be used to control the + * number of rows per fetch and "queryTimeout" can be used to wait for a Statement object to + * execute to the given number of seconds. + * @since 1.4.0 + */ + // scalastyle:on line.size.limit + def jdbc( + url: String, + table: String, + columnName: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int, + connectionProperties: util.Properties): Dataset[Row] = { + // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions. + this.extraOptions ++= Map( + "partitionColumn" -> columnName, + "lowerBound" -> lowerBound.toString, + "upperBound" -> upperBound.toString, + "numPartitions" -> numPartitions.toString) + jdbc(url, table, connectionProperties) + } + + /** + * Construct a `DataFrame` representing the database table accessible via JDBC URL url named + * table using connection properties. The `predicates` parameter gives a list expressions + * suitable for inclusion in WHERE clauses; each one defines one partition of the `DataFrame`. + * + * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash + * your external database systems. 
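For illustration only, a hedged sketch of the two `jdbc` entry points above; the URL, table name, partition column and credentials are placeholders:
{{{
  import java.util.Properties

  val props = new Properties()
  props.setProperty("user", "sa")          // placeholder credentials
  props.setProperty("password", "secret")

  // Read a whole table through a single connection.
  val orders = spark.read.jdbc("jdbc:postgresql://host/db", "orders", props)

  // Partitioned read: 8 parallel partitions over the numeric column `id`.
  val partitioned = spark.read.jdbc(
    "jdbc:postgresql://host/db", "orders", "id",
    /* lowerBound = */ 0L, /* upperBound = */ 1000000L, /* numPartitions = */ 8, props)
}}}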
+ * + * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC + * in + * Data Source Option in the version you use. + * + * @param table + * Name of the table in the external database. + * @param predicates + * Condition in the where clause for each partition. + * @param connectionProperties + * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least + * a "user" and "password" property should be included. "fetchsize" can be used to control the + * number of rows per fetch. + * @since 1.4.0 + */ + def jdbc( + url: String, + table: String, + predicates: Array[String], + connectionProperties: util.Properties): Dataset[Row] + + /** + * Loads a JSON file and returns the results as a `DataFrame`. + * + * See the documentation on the overloaded `json()` method with varargs for more details. + * + * @since 1.4.0 + */ + def json(path: String): Dataset[Row] = { + // This method ensures that calls that explicit need single argument works, see SPARK-16009 + json(Seq(path): _*) + } + + /** + * Loads JSON files and returns the results as a `DataFrame`. + * + * JSON Lines (newline-delimited JSON) is supported by + * default. For JSON (one record per file), set the `multiLine` option to true. + * + * This function goes through the input once to determine the input schema. If you know the + * schema in advance, use the version that specifies the schema to avoid the extra scan. + * + * You can find the JSON-specific options for reading JSON files in + * Data Source Option in the version you use. + * + * @since 2.0.0 + */ + @scala.annotation.varargs + def json(paths: String*): Dataset[Row] = { + validateJsonSchema() + format("json").load(paths: _*) + } + + /** + * Loads a `Dataset[String]` storing JSON objects (JSON Lines + * text format or newline-delimited JSON) and returns the result as a `DataFrame`. + * + * Unless the schema is specified using `schema` function, this function goes through the input + * once to determine the input schema. + * + * @param jsonDataset + * input Dataset with one JSON object per record + * @since 2.2.0 + */ + def json(jsonDataset: DS[String]): Dataset[Row] + + /** + * Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other + * overloaded `csv()` method for more details. + * + * @since 2.0.0 + */ + def csv(path: String): Dataset[Row] = { + // This method ensures that calls that explicit need single argument works, see SPARK-16009 + csv(Seq(path): _*) + } + + /** + * Loads an `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`. + * + * If the schema is not specified using `schema` function and `inferSchema` option is enabled, + * this function goes through the input once to determine the input schema. + * + * If the schema is not specified using `schema` function and `inferSchema` option is disabled, + * it determines the columns as string types and it reads only the first line to determine the + * names and the number of fields. + * + * If the enforceSchema is set to `false`, only the CSV header in the first line is checked to + * conform specified or inferred schema. + * + * @note + * if `header` option is set to `true` when calling this API, all lines same with the header + * will be removed if exists. + * + * @param csvDataset + * input Dataset with one CSV row per record + * @since 2.2.0 + */ + def csv(csvDataset: DS[String]): Dataset[Row] + + /** + * Loads CSV files and returns the result as a `DataFrame`. 
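A small illustrative sketch of the JSON loaders above; the path and sample record are placeholders, and `spark` plus `spark.implicits._` are assumed to be in scope:
{{{
  // Multi-line JSON with an explicit schema, avoiding the extra inference pass.
  val people = spark.read
    .schema("name STRING, age INT")
    .option("multiLine", "true")
    .json("/tmp/people.json")                 // placeholder path

  // JSON held in an in-memory Dataset[String].
  import spark.implicits._
  val ds = Seq("""{"name":"Ann","age":12}""").toDS()
  val fromDs = spark.read.json(ds)
}}}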
+ * + * This function will go through the input once to determine the input schema if `inferSchema` + * is enabled. To avoid going through the entire data once, disable `inferSchema` option or + * specify the schema explicitly using `schema`. + * + * You can find the CSV-specific options for reading CSV files in + * Data Source Option in the version you use. + * + * @since 2.0.0 + */ + @scala.annotation.varargs + def csv(paths: String*): Dataset[Row] = format("csv").load(paths: _*) + + /** + * Loads a XML file and returns the result as a `DataFrame`. See the documentation on the other + * overloaded `xml()` method for more details. + * + * @since 4.0.0 + */ + def xml(path: String): Dataset[Row] = { + // This method ensures that calls that explicit need single argument works, see SPARK-16009 + xml(Seq(path): _*) + } + + /** + * Loads XML files and returns the result as a `DataFrame`. + * + * This function will go through the input once to determine the input schema if `inferSchema` + * is enabled. To avoid going through the entire data once, disable `inferSchema` option or + * specify the schema explicitly using `schema`. + * + * You can find the XML-specific options for reading XML files in + * Data Source Option in the version you use. + * + * @since 4.0.0 + */ + @scala.annotation.varargs + def xml(paths: String*): Dataset[Row] = { + validateXmlSchema() + format("xml").load(paths: _*) + } + + /** + * Loads an `Dataset[String]` storing XML object and returns the result as a `DataFrame`. + * + * If the schema is not specified using `schema` function and `inferSchema` option is enabled, + * this function goes through the input once to determine the input schema. + * + * @param xmlDataset + * input Dataset with one XML object per record + * @since 4.0.0 + */ + def xml(xmlDataset: DS[String]): Dataset[Row] + + /** + * Loads a Parquet file, returning the result as a `DataFrame`. See the documentation on the + * other overloaded `parquet()` method for more details. + * + * @since 2.0.0 + */ + def parquet(path: String): Dataset[Row] = { + // This method ensures that calls that explicit need single argument works, see SPARK-16009 + parquet(Seq(path): _*) + } + + /** + * Loads a Parquet file, returning the result as a `DataFrame`. + * + * Parquet-specific option(s) for reading Parquet files can be found in Data + * Source Option in the version you use. + * + * @since 1.4.0 + */ + @scala.annotation.varargs + def parquet(paths: String*): Dataset[Row] = format("parquet").load(paths: _*) + + /** + * Loads an ORC file and returns the result as a `DataFrame`. + * + * @param path + * input path + * @since 1.5.0 + */ + def orc(path: String): Dataset[Row] = { + // This method ensures that calls that explicit need single argument works, see SPARK-16009 + orc(Seq(path): _*) + } + + /** + * Loads ORC files and returns the result as a `DataFrame`. + * + * ORC-specific option(s) for reading ORC files can be found in Data + * Source Option in the version you use. + * + * @param paths + * input paths + * @since 2.0.0 + */ + @scala.annotation.varargs + def orc(paths: String*): Dataset[Row] = format("orc").load(paths: _*) + + /** + * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch + * reading and the returned DataFrame is the batch scan query plan of this table. If it's a + * view, the returned DataFrame is simply the query plan of the view, which can either be a + * batch or streaming query plan. 
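A brief, non-authoritative sketch of the file-based readers above; all paths are placeholders:
{{{
  // CSV with header and schema inference (inference costs one extra pass over the data).
  val csvDf = spark.read
    .option("header", "true")
    .option("inferSchema", "true")
    .csv("/tmp/data1.csv", "/tmp/data2.csv")  // placeholder paths

  // Columnar formats carry their own schema, so no options are needed for the common case.
  val parquetDf = spark.read.parquet("/tmp/events.parquet")
  val orcDf     = spark.read.orc("/tmp/events.orc")
}}}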
+ * + * @param tableName + * is either a qualified or unqualified name that designates a table or view. If a database is + * specified, it identifies the table/view from the database. Otherwise, it first attempts to + * find a temporary view with the given name and then match the table/view from the current + * database. Note that, the global temporary view database is also valid here. + * @since 1.4.0 + */ + def table(tableName: String): Dataset[Row] + + /** + * Loads text files and returns a `DataFrame` whose schema starts with a string column named + * "value", and followed by partitioned columns if there are any. See the documentation on the + * other overloaded `text()` method for more details. + * + * @since 2.0.0 + */ + def text(path: String): Dataset[Row] = { + // This method ensures that calls that explicit need single argument works, see SPARK-16009 + text(Seq(path): _*) + } + + /** + * Loads text files and returns a `DataFrame` whose schema starts with a string column named + * "value", and followed by partitioned columns if there are any. The text files must be encoded + * as UTF-8. + * + * By default, each line in the text files is a new row in the resulting DataFrame. For example: + * {{{ + * // Scala: + * spark.read.text("/path/to/spark/README.md") + * + * // Java: + * spark.read().text("/path/to/spark/README.md") + * }}} + * + * You can find the text-specific options for reading text files in + * Data Source Option in the version you use. + * + * @param paths + * input paths + * @since 1.6.0 + */ + @scala.annotation.varargs + def text(paths: String*): Dataset[Row] = format("text").load(paths: _*) + + /** + * Loads text files and returns a [[Dataset]] of String. See the documentation on the other + * overloaded `textFile()` method for more details. + * @since 2.0.0 + */ + def textFile(path: String): Dataset[String] = { + // This method ensures that calls that explicit need single argument works, see SPARK-16009 + textFile(Seq(path): _*) + } + + /** + * Loads text files and returns a [[Dataset]] of String. The underlying schema of the Dataset + * contains a single string column named "value". The text files must be encoded as UTF-8. + * + * If the directory structure of the text files contains partitioning information, those are + * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. + * + * By default, each line in the text files is a new row in the resulting DataFrame. For example: + * {{{ + * // Scala: + * spark.read.textFile("/path/to/spark/README.md") + * + * // Java: + * spark.read().textFile("/path/to/spark/README.md") + * }}} + * + * You can set the text-specific options as specified in `DataFrameReader.text`. + * + * @param paths + * input path + * @since 2.0.0 + */ + @scala.annotation.varargs + def textFile(paths: String*): Dataset[String] = { + assertNoSpecifiedSchema("textFile") + text(paths: _*).select("value").as(StringEncoder) + } + + /** + * A convenient function for schema validation in APIs. + */ + protected def assertNoSpecifiedSchema(operation: String): Unit = { + if (userSpecifiedSchema.nonEmpty) { + throw DataTypeErrors.userSpecifiedSchemaUnsupportedError(operation) + } + } + + /** + * Ensure that the `singleVariantColumn` option cannot be used if there is also a user specified + * schema. 
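As a hedged illustration of `table`, `text` and `textFile` described above; the table name and paths are placeholders:
{{{
  // Catalog table or view.
  val t = spark.read.table("sales.orders")    // placeholder table name

  // text() yields a DataFrame with a single "value" column (plus partition columns, if any);
  // textFile() yields a Dataset[String] and ignores partitioning information.
  val lines   = spark.read.text("/path/to/spark/README.md")
  val strings = spark.read.textFile("/path/to/spark/README.md")
}}}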
+ */ + protected def validateSingleVariantColumn(): Unit = () + + protected def validateJsonSchema(): Unit = () + + protected def validateXmlSchema(): Unit = () + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + protected var source: String = _ + + protected var userSpecifiedSchema: Option[StructType] = None + + protected var extraOptions: CaseInsensitiveMap[String] = CaseInsensitiveMap[String](Map.empty) +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala index c3ecc7b90d5b4..ae7c256b30ace 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala @@ -34,38 +34,41 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} * @since 1.4.0 */ @Stable -abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { - protected def df: DS[Row] +abstract class DataFrameStatFunctions { + protected def df: Dataset[Row] /** * Calculates the approximate quantiles of a numerical column of a DataFrame. * - * The result of this algorithm has the following deterministic bound: - * If the DataFrame has N elements and if we request the quantile at probability `p` up to error - * `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank - * of `x` is close to (p * N). - * More precisely, + * The result of this algorithm has the following deterministic bound: If the DataFrame has N + * elements and if we request the quantile at probability `p` up to error `err`, then the + * algorithm will return a sample `x` from the DataFrame so that the *exact* rank of `x` is + * close to (p * N). More precisely, * * {{{ * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N) * }}} * * This method implements a variation of the Greenwald-Khanna algorithm (with some speed - * optimizations). - * The algorithm was first present in - * Space-efficient Online Computation of Quantile Summaries by Greenwald and Khanna. - * - * @param col the name of the numerical column - * @param probabilities a list of quantile probabilities - * Each number must belong to [0, 1]. - * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. - * @param relativeError The relative target precision to achieve (greater than or equal to 0). - * If set to zero, the exact quantiles are computed, which could be very expensive. - * Note that values greater than 1 are accepted but give the same result as 1. - * @return the approximate quantiles at the given probabilities - * - * @note null and NaN values will be removed from the numerical column before calculation. If - * the dataframe is empty or the column only contains null or NaN, an empty array is returned. + * optimizations). The algorithm was first present in Space-efficient Online Computation of Quantile + * Summaries by Greenwald and Khanna. + * + * @param col + * the name of the numerical column + * @param probabilities + * a list of quantile probabilities Each number must belong to [0, 1]. For example 0 is the + * minimum, 0.5 is the median, 1 is the maximum. + * @param relativeError + * The relative target precision to achieve (greater than or equal to 0). 
If set to zero, the + * exact quantiles are computed, which could be very expensive. Note that values greater than + * 1 are accepted but give the same result as 1. + * @return + * the approximate quantiles at the given probabilities + * + * @note + * null and NaN values will be removed from the numerical column before calculation. If the + * dataframe is empty or the column only contains null or NaN, an empty array is returned. * * @since 2.0.0 */ @@ -78,19 +81,24 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { /** * Calculates the approximate quantiles of numerical columns of a DataFrame. - * @see `approxQuantile(col:Str* approxQuantile)` for detailed description. - * - * @param cols the names of the numerical columns - * @param probabilities a list of quantile probabilities - * Each number must belong to [0, 1]. - * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. - * @param relativeError The relative target precision to achieve (greater than or equal to 0). - * If set to zero, the exact quantiles are computed, which could be very expensive. - * Note that values greater than 1 are accepted but give the same result as 1. - * @return the approximate quantiles at the given probabilities of each column - * - * @note null and NaN values will be ignored in numerical columns before calculation. For - * columns only containing null or NaN values, an empty array is returned. + * @see + * `approxQuantile(col:Str* approxQuantile)` for detailed description. + * + * @param cols + * the names of the numerical columns + * @param probabilities + * a list of quantile probabilities Each number must belong to [0, 1]. For example 0 is the + * minimum, 0.5 is the median, 1 is the maximum. + * @param relativeError + * The relative target precision to achieve (greater than or equal to 0). If set to zero, the + * exact quantiles are computed, which could be very expensive. Note that values greater than + * 1 are accepted but give the same result as 1. + * @return + * the approximate quantiles at the given probabilities of each column + * + * @note + * null and NaN values will be ignored in numerical columns before calculation. For columns + * only containing null or NaN values, an empty array is returned. * * @since 2.2.0 */ @@ -102,9 +110,12 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { /** * Calculate the sample covariance of two numerical columns of a DataFrame. * - * @param col1 the name of the first column - * @param col2 the name of the second column - * @return the covariance of the two columns. + * @param col1 + * the name of the first column + * @param col2 + * the name of the second column + * @return + * the covariance of the two columns. * * {{{ * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) @@ -121,9 +132,12 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in * MLlib's Statistics. * - * @param col1 the name of the column - * @param col2 the name of the column to calculate the correlation against - * @return The Pearson Correlation Coefficient as a Double. + * @param col1 + * the name of the column + * @param col2 + * the name of the column to calculate the correlation against + * @return + * The Pearson Correlation Coefficient as a Double. 
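A minimal sketch of `approxQuantile` as described above, assuming an existing `SparkSession` named `spark`; the column and probabilities are illustrative:
{{{
  import org.apache.spark.sql.functions.rand

  val df = spark.range(0, 1000).withColumn("x", rand(seed = 10L))

  // Median and 95th percentile of `x`, allowing a 1% relative rank error;
  // passing 0.0 would compute exact quantiles, which can be very expensive.
  val quantiles = df.stat.approxQuantile("x", Array(0.5, 0.95), 0.01)
}}}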
* * {{{ * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) @@ -138,9 +152,12 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { /** * Calculates the Pearson Correlation Coefficient of two columns of a DataFrame. * - * @param col1 the name of the column - * @param col2 the name of the column to calculate the correlation against - * @return The Pearson Correlation Coefficient as a Double. + * @param col1 + * the name of the column + * @param col2 + * the name of the column to calculate the correlation against + * @return + * The Pearson Correlation Coefficient as a Double. * * {{{ * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) @@ -159,14 +176,15 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * The first column of each row will be the distinct values of `col1` and the column names will * be the distinct values of `col2`. The name of the first column will be `col1_col2`. Counts * will be returned as `Long`s. Pairs that have no occurrences will have zero as their counts. - * Null elements will be replaced by "null", and back ticks will be dropped from elements if they - * exist. + * Null elements will be replaced by "null", and back ticks will be dropped from elements if + * they exist. * - * @param col1 The name of the first column. Distinct items will make the first item of - * each row. - * @param col2 The name of the second column. Distinct items will make the column names - * of the DataFrame. - * @return A DataFrame containing for the contingency table. + * @param col1 + * The name of the first column. Distinct items will make the first item of each row. + * @param col2 + * The name of the second column. Distinct items will make the column names of the DataFrame. + * @return + * A DataFrame containing for the contingency table. * * {{{ * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3))) @@ -184,22 +202,22 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.4.0 */ - def crosstab(col1: String, col2: String): DS[Row] + def crosstab(col1: String, col2: String): Dataset[Row] /** - * Finding frequent items for columns, possibly with false positives. Using the - * frequent element count algorithm described in - * here, proposed by Karp, - * Schenker, and Papadimitriou. - * The `support` should be greater than 1e-4. + * Finding frequent items for columns, possibly with false positives. Using the frequent element + * count algorithm described in here, + * proposed by Karp, Schenker, and Papadimitriou. The `support` should be greater than 1e-4. * * This function is meant for exploratory data analysis, as we make no guarantee about the * backward compatibility of the schema of the resulting `DataFrame`. * - * @param cols the names of the columns to search frequent items in. - * @param support The minimum frequency for an item to be considered `frequent`. Should be greater - * than 1e-4. - * @return A Local DataFrame with the Array of frequent items for each column. + * @param cols + * the names of the columns to search frequent items in. + * @param support + * The minimum frequency for an item to be considered `frequent`. Should be greater than 1e-4. + * @return + * A Local DataFrame with the Array of frequent items for each column. 
* * {{{ * val rows = Seq.tabulate(100) { i => @@ -228,36 +246,38 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * }}} * @since 1.4.0 */ - def freqItems(cols: Array[String], support: Double): DS[Row] = + def freqItems(cols: Array[String], support: Double): Dataset[Row] = freqItems(cols.toImmutableArraySeq, support) /** - * Finding frequent items for columns, possibly with false positives. Using the - * frequent element count algorithm described in - * here, proposed by Karp, - * Schenker, and Papadimitriou. - * Uses a `default` support of 1%. + * Finding frequent items for columns, possibly with false positives. Using the frequent element + * count algorithm described in here, + * proposed by Karp, Schenker, and Papadimitriou. Uses a `default` support of 1%. * * This function is meant for exploratory data analysis, as we make no guarantee about the * backward compatibility of the schema of the resulting `DataFrame`. * - * @param cols the names of the columns to search frequent items in. - * @return A Local DataFrame with the Array of frequent items for each column. + * @param cols + * the names of the columns to search frequent items in. + * @return + * A Local DataFrame with the Array of frequent items for each column. * @since 1.4.0 */ - def freqItems(cols: Array[String]): DS[Row] = freqItems(cols, 0.01) + def freqItems(cols: Array[String]): Dataset[Row] = freqItems(cols, 0.01) /** * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the - * frequent element count algorithm described in - * here, proposed by Karp, Schenker, - * and Papadimitriou. + * frequent element count algorithm described in here, proposed by Karp, Schenker, and + * Papadimitriou. * * This function is meant for exploratory data analysis, as we make no guarantee about the * backward compatibility of the schema of the resulting `DataFrame`. * - * @param cols the names of the columns to search frequent items in. - * @return A Local DataFrame with the Array of frequent items for each column. + * @param cols + * the names of the columns to search frequent items in. + * @return + * A Local DataFrame with the Array of frequent items for each column. * * {{{ * val rows = Seq.tabulate(100) { i => @@ -287,32 +307,38 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.4.0 */ - def freqItems(cols: Seq[String], support: Double): DS[Row] + def freqItems(cols: Seq[String], support: Double): Dataset[Row] /** * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the - * frequent element count algorithm described in - * here, proposed by Karp, Schenker, - * and Papadimitriou. - * Uses a `default` support of 1%. + * frequent element count algorithm described in here, proposed by Karp, Schenker, and + * Papadimitriou. Uses a `default` support of 1%. * * This function is meant for exploratory data analysis, as we make no guarantee about the * backward compatibility of the schema of the resulting `DataFrame`. * - * @param cols the names of the columns to search frequent items in. - * @return A Local DataFrame with the Array of frequent items for each column. + * @param cols + * the names of the columns to search frequent items in. + * @return + * A Local DataFrame with the Array of frequent items for each column. 
* @since 1.4.0 */ - def freqItems(cols: Seq[String]): DS[Row] = freqItems(cols, 0.01) + def freqItems(cols: Seq[String]): Dataset[Row] = freqItems(cols, 0.01) /** * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col column that defines strata - * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat - * its fraction as zero. - * @param seed random seed - * @tparam T stratum type - * @return a new `DataFrame` that represents the stratified sample + * @param col + * column that defines strata + * @param fractions + * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as + * zero. + * @param seed + * random seed + * @tparam T + * stratum type + * @return + * a new `DataFrame` that represents the stratified sample * * {{{ * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), @@ -330,33 +356,43 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.5.0 */ - def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DS[Row] = { + def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): Dataset[Row] = { sampleBy(Column(col), fractions, seed) } /** * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col column that defines strata - * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat - * its fraction as zero. - * @param seed random seed - * @tparam T stratum type - * @return a new `DataFrame` that represents the stratified sample + * @param col + * column that defines strata + * @param fractions + * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as + * zero. + * @param seed + * random seed + * @tparam T + * stratum type + * @return + * a new `DataFrame` that represents the stratified sample * * @since 1.5.0 */ - def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DS[Row] = { + def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): Dataset[Row] = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) } /** * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col column that defines strata - * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat - * its fraction as zero. - * @param seed random seed - * @tparam T stratum type - * @return a new `DataFrame` that represents the stratified sample + * @param col + * column that defines strata + * @param fractions + * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as + * zero. + * @param seed + * random seed + * @tparam T + * stratum type + * @return + * a new `DataFrame` that represents the stratified sample * * The stratified sample can be performed over multiple columns: * {{{ @@ -377,33 +413,42 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * * @since 3.0.0 */ - def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DS[Row] - + def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): Dataset[Row] /** * (Java-specific) Returns a stratified sample without replacement based on the fraction given * on each stratum. * - * @param col column that defines strata - * @param fractions sampling fraction for each stratum. 
If a stratum is not specified, we treat - * its fraction as zero. - * @param seed random seed - * @tparam T stratum type - * @return a new `DataFrame` that represents the stratified sample + * @param col + * column that defines strata + * @param fractions + * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as + * zero. + * @param seed + * random seed + * @tparam T + * stratum type + * @return + * a new `DataFrame` that represents the stratified sample * @since 3.0.0 */ - def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DS[Row] = { + def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): Dataset[Row] = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) } /** * Builds a Count-min Sketch over a specified column. * - * @param colName name of the column over which the sketch is built - * @param depth depth of the sketch - * @param width width of the sketch - * @param seed random seed - * @return a `CountMinSketch` over column `colName` + * @param colName + * name of the column over which the sketch is built + * @param depth + * depth of the sketch + * @param width + * width of the sketch + * @param seed + * random seed + * @return + * a `CountMinSketch` over column `colName` * @since 2.0.0 */ def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = { @@ -413,26 +458,39 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { /** * Builds a Count-min Sketch over a specified column. * - * @param colName name of the column over which the sketch is built - * @param eps relative error of the sketch - * @param confidence confidence of the sketch - * @param seed random seed - * @return a `CountMinSketch` over column `colName` + * @param colName + * name of the column over which the sketch is built + * @param eps + * relative error of the sketch + * @param confidence + * confidence of the sketch + * @param seed + * random seed + * @return + * a `CountMinSketch` over column `colName` * @since 2.0.0 */ def countMinSketch( - colName: String, eps: Double, confidence: Double, seed: Int): CountMinSketch = { + colName: String, + eps: Double, + confidence: Double, + seed: Int): CountMinSketch = { countMinSketch(Column(colName), eps, confidence, seed) } /** * Builds a Count-min Sketch over a specified column. * - * @param col the column over which the sketch is built - * @param depth depth of the sketch - * @param width width of the sketch - * @param seed random seed - * @return a `CountMinSketch` over column `colName` + * @param col + * the column over which the sketch is built + * @param depth + * depth of the sketch + * @param width + * width of the sketch + * @param seed + * random seed + * @return + * a `CountMinSketch` over column `colName` * @since 2.0.0 */ def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = { @@ -444,29 +502,34 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { /** * Builds a Count-min Sketch over a specified column. 
* - * @param col the column over which the sketch is built - * @param eps relative error of the sketch - * @param confidence confidence of the sketch - * @param seed random seed - * @return a `CountMinSketch` over column `colName` + * @param col + * the column over which the sketch is built + * @param eps + * relative error of the sketch + * @param confidence + * confidence of the sketch + * @param seed + * random seed + * @return + * a `CountMinSketch` over column `colName` * @since 2.0.0 */ - def countMinSketch( - col: Column, - eps: Double, - confidence: Double, - seed: Int): CountMinSketch = withOrigin { - val cms = count_min_sketch(col, lit(eps), lit(confidence), lit(seed)) - val bytes: Array[Byte] = df.select(cms).as(BinaryEncoder).head() - CountMinSketch.readFrom(bytes) - } + def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = + withOrigin { + val cms = count_min_sketch(col, lit(eps), lit(confidence), lit(seed)) + val bytes: Array[Byte] = df.select(cms).as(BinaryEncoder).head() + CountMinSketch.readFrom(bytes) + } /** * Builds a Bloom filter over a specified column. * - * @param colName name of the column over which the filter is built - * @param expectedNumItems expected number of items which will be put into the filter. - * @param fpp expected false positive probability of the filter. + * @param colName + * name of the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param fpp + * expected false positive probability of the filter. * @since 2.0.0 */ def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { @@ -476,9 +539,12 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { /** * Builds a Bloom filter over a specified column. * - * @param col the column over which the filter is built - * @param expectedNumItems expected number of items which will be put into the filter. - * @param fpp expected false positive probability of the filter. + * @param col + * the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param fpp + * expected false positive probability of the filter. * @since 2.0.0 */ def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { @@ -489,9 +555,12 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { /** * Builds a Bloom filter over a specified column. * - * @param colName name of the column over which the filter is built - * @param expectedNumItems expected number of items which will be put into the filter. - * @param numBits expected number of bits of the filter. + * @param colName + * name of the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param numBits + * expected number of bits of the filter. * @since 2.0.0 */ def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { @@ -501,9 +570,12 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { /** * Builds a Bloom filter over a specified column. * - * @param col the column over which the filter is built - * @param expectedNumItems expected number of items which will be put into the filter. - * @param numBits expected number of bits of the filter. 
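Purely as an illustration of `sampleBy` and the `countMinSketch`/`bloomFilter` builders documented above; the data, fractions and accuracy parameters below are placeholders:
{{{
  val df = spark.createDataFrame(Seq((1, "a"), (1, "b"), (2, "c"), (2, "d"), (3, "e")))
    .toDF("key", "value")

  // Stratified sample: keep ~50% of key 1, all of key 3; unspecified strata default to 0.
  val sampled = df.stat.sampleBy("key", Map(1 -> 0.5, 3 -> 1.0), seed = 7L)

  // Approximate per-item frequencies and approximate set membership, respectively.
  val cms   = df.stat.countMinSketch("key", eps = 0.01, confidence = 0.95, seed = 42)
  val bloom = df.stat.bloomFilter("value", expectedNumItems = 1000L, fpp = 0.03)

  val estimatedCountOfKey1 = cms.estimateCount(1)
  val maybeContainsA       = bloom.mightContain("a")
}}}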
+ * @param col + * the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param numBits + * expected number of bits of the filter. * @since 2.0.0 */ def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = withOrigin { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala index 49f77a1a61204..6eef034aa5157 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala @@ -23,15 +23,16 @@ import _root_.java.util import org.apache.spark.annotation.{DeveloperApi, Stable} import org.apache.spark.api.java.function.{FilterFunction, FlatMapFunction, ForeachFunction, ForeachPartitionFunction, MapFunction, MapPartitionsFunction, ReduceFunction} -import org.apache.spark.sql.{functions, AnalysisException, Column, DataFrameWriter, Encoder, Observation, Row, TypedColumn} +import org.apache.spark.sql.{functions, AnalysisException, Column, DataFrameWriter, DataFrameWriterV2, Encoder, MergeIntoWriter, Observation, Row, TypedColumn} +import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors} import org.apache.spark.sql.types.{Metadata, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SparkClassUtils /** - * A Dataset is a strongly typed collection of domain-specific objects that can be transformed - * in parallel using functional or relational operations. Each Dataset also has an untyped view + * A Dataset is a strongly typed collection of domain-specific objects that can be transformed in + * parallel using functional or relational operations. Each Dataset also has an untyped view * called a `DataFrame`, which is a Dataset of [[org.apache.spark.sql.Row]]. * * Operations available on Datasets are divided into transformations and actions. Transformations @@ -39,29 +40,29 @@ import org.apache.spark.util.SparkClassUtils * return results. Example transformations include map, filter, select, and aggregate (`groupBy`). * Example actions count, show, or writing data out to file systems. * - * Datasets are "lazy", i.e. computations are only triggered when an action is invoked. Internally, - * a Dataset represents a logical plan that describes the computation required to produce the data. - * When an action is invoked, Spark's query optimizer optimizes the logical plan and generates a - * physical plan for efficient execution in a parallel and distributed manner. To explore the - * logical plan as well as optimized physical plan, use the `explain` function. + * Datasets are "lazy", i.e. computations are only triggered when an action is invoked. + * Internally, a Dataset represents a logical plan that describes the computation required to + * produce the data. When an action is invoked, Spark's query optimizer optimizes the logical plan + * and generates a physical plan for efficient execution in a parallel and distributed manner. To + * explore the logical plan as well as optimized physical plan, use the `explain` function. * - * To efficiently support domain-specific objects, an [[org.apache.spark.sql.Encoder]] is required. - * The encoder maps the domain specific type `T` to Spark's internal type system. 
For example, given - * a class `Person` with two fields, `name` (string) and `age` (int), an encoder is used to tell - * Spark to generate code at runtime to serialize the `Person` object into a binary structure. This - * binary structure often has much lower memory footprint as well as are optimized for efficiency - * in data processing (e.g. in a columnar format). To understand the internal binary representation - * for data, use the `schema` function. + * To efficiently support domain-specific objects, an [[org.apache.spark.sql.Encoder]] is + * required. The encoder maps the domain specific type `T` to Spark's internal type system. For + * example, given a class `Person` with two fields, `name` (string) and `age` (int), an encoder is + * used to tell Spark to generate code at runtime to serialize the `Person` object into a binary + * structure. This binary structure often has much lower memory footprint as well as are optimized + * for efficiency in data processing (e.g. in a columnar format). To understand the internal + * binary representation for data, use the `schema` function. * - * There are typically two ways to create a Dataset. The most common way is by pointing Spark - * to some files on storage systems, using the `read` function available on a `SparkSession`. + * There are typically two ways to create a Dataset. The most common way is by pointing Spark to + * some files on storage systems, using the `read` function available on a `SparkSession`. * {{{ * val people = spark.read.parquet("...").as[Person] // Scala * Dataset people = spark.read().parquet("...").as(Encoders.bean(Person.class)); // Java * }}} * - * Datasets can also be created through transformations available on existing Datasets. For example, - * the following creates a new Dataset by applying a filter on the existing one: + * Datasets can also be created through transformations available on existing Datasets. For + * example, the following creates a new Dataset by applying a filter on the existing one: * {{{ * val names = people.map(_.name) // in Scala; names is a Dataset[String] * Dataset names = people.map( @@ -70,8 +71,8 @@ import org.apache.spark.util.SparkClassUtils * * Dataset operations can also be untyped, through various domain-specific-language (DSL) * functions defined in: Dataset (this class), [[org.apache.spark.sql.Column]], and - * [[org.apache.spark.sql.functions]]. These operations are very similar to the operations available - * in the data frame abstraction in R or Python. + * [[org.apache.spark.sql.functions]]. These operations are very similar to the operations + * available in the data frame abstraction in R or Python. * * To select a column from the Dataset, use `apply` method in Scala and `col` in Java. * {{{ @@ -118,10 +119,10 @@ import org.apache.spark.util.SparkClassUtils * @since 1.6.0 */ @Stable -abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { - type RGD <: RelationalGroupedDataset[DS] +abstract class Dataset[T] extends Serializable { + type DS[U] <: Dataset[U] - def sparkSession: SparkSession[DS] + def sparkSession: SparkSession val encoder: Encoder[T] @@ -135,52 +136,46 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { */ // This is declared with parentheses to prevent the Scala compiler from treating // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DS[Row] - - /** - * Returns a new Dataset where each record has been mapped on to the specified type. 
The - * method used to map columns depend on the type of `U`: - *
  • When `U` is a class, fields for the class will be mapped to columns of the same name - * (case sensitivity is determined by `spark.sql.caseSensitive`).
  • When `U` is a tuple, the columns will be mapped by ordinal (i.e. the first column will - * be assigned to `_1`).
  • When `U` is a primitive type (i.e. String, Int, etc), then the first column of the - * `DataFrame` will be used.
  + def toDF(): Dataset[Row] + + /** + * Returns a new Dataset where each record has been mapped on to the specified type. The method + * used to map columns depend on the type of `U`:
    • When `U` is a class, fields for the + * class will be mapped to columns of the same name (case sensitivity is determined by + * `spark.sql.caseSensitive`).
    • When `U` is a tuple, the columns will be mapped by + * ordinal (i.e. the first column will be assigned to `_1`).
    • When `U` is a primitive + * type (i.e. String, Int, etc), then the first column of the `DataFrame` will be used.
    * - * If the schema of the Dataset does not match the desired `U` type, you can use `select` - * along with `alias` or `as` to rearrange or rename as required. + * If the schema of the Dataset does not match the desired `U` type, you can use `select` along + * with `alias` or `as` to rearrange or rename as required. * - * Note that `as[]` only changes the view of the data that is passed into typed operations, - * such as `map()`, and does not eagerly project away any columns that are not present in - * the specified class. + * Note that `as[]` only changes the view of the data that is passed into typed operations, such + * as `map()`, and does not eagerly project away any columns that are not present in the + * specified class. * * @group basic * @since 1.6.0 */ - def as[U: Encoder]: DS[U] + def as[U: Encoder]: Dataset[U] /** - * Returns a new DataFrame where each row is reconciled to match the specified schema. Spark will: - *
    • Reorder columns and/or inner fields by name to match the specified schema.
    • Project away columns and/or inner fields that are not needed by the specified schema. - * Missing columns and/or inner fields (present in the specified schema but not input DataFrame) - * lead to failures.
    • Cast the columns and/or inner fields to match the data types in the specified schema, if - * the types are compatible, e.g., numeric to numeric (error if overflows), but not string to - * int.
    • Carry over the metadata from the specified schema, while the columns and/or inner fields - * still keep their own metadata if not overwritten by the specified schema.
    • Fail if the nullability is not compatible. For example, the column and/or inner field is - * nullable but the specified schema requires them to be not nullable.
    + * Returns a new DataFrame where each row is reconciled to match the specified schema. Spark + * will:
    • Reorder columns and/or inner fields by name to match the specified + * schema.
    • Project away columns and/or inner fields that are not needed by the + * specified schema. Missing columns and/or inner fields (present in the specified schema but + * not input DataFrame) lead to failures.
    • Cast the columns and/or inner fields to match + * the data types in the specified schema, if the types are compatible, e.g., numeric to numeric + * (error if overflows), but not string to int.
    • Carry over the metadata from the + * specified schema, while the columns and/or inner fields still keep their own metadata if not + * overwritten by the specified schema.
    • Fail if the nullability is not compatible. For + * example, the column and/or inner field is nullable but the specified schema requires them to + * be not nullable.
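A hedged example of the schema reconciliation performed by `to(schema)` as listed above; the input data and target schema are illustrative:
{{{
  import org.apache.spark.sql.types._

  val df = spark.createDataFrame(Seq((1, "a"))).toDF("i", "j")

  // Reorder columns by name and cast the compatible numeric column to LongType;
  // an incompatible cast or a missing field would fail instead.
  val target = StructType(Seq(StructField("j", StringType), StructField("i", LongType)))
  val reconciled = df.to(target)
}}}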
    * * @group basic * @since 3.4.0 */ - def to(schema: StructType): DS[Row] + def to(schema: StructType): Dataset[Row] /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. @@ -196,7 +191,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def toDF(colNames: String*): DS[Row] + def toDF(colNames: String*): Dataset[Row] /** * Returns the schema of this Dataset. @@ -228,16 +223,12 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { /** * Prints the plans (logical and physical) with a format specified by a given explain mode. * - * @param mode specifies the expected output format of plans. - *
    • `simple` Print only a physical plan.
    • `extended`: Print both logical and physical plans.
    • `codegen`: Print a physical plan and generated codes if they are - * available.
    • `cost`: Print a logical plan and statistics if they are available.
    • `formatted`: Split explain output into two sections: a physical plan outline - * and node details.
    + * @param mode + * specifies the expected output format of plans.
    • `simple` Print only a physical + * plan.
    • `extended`: Print both logical and physical plans.
    • `codegen`: Print + * a physical plan and generated codes if they are available.
    • `cost`: Print a logical + * plan and statistics if they are available.
    • `formatted`: Split explain output into + * two sections: a physical plan outline and node details.
    * @group basic * @since 3.0.0 */ @@ -246,12 +237,12 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { /** * Prints the plans (logical and physical) to the console for debugging purposes. * - * @param extended default `false`. If `false`, prints only the physical plan. + * @param extended + * default `false`. If `false`, prints only the physical plan. * @group basic * @since 1.6.0 */ def explain(extended: Boolean): Unit = if (extended) { - // TODO move ExplainMode? explain("extended") } else { explain("simple") @@ -284,8 +275,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { def columns: Array[String] = schema.fields.map(_.name) /** - * Returns true if the `collect` and `take` methods can be run locally - * (without any Spark executors). + * Returns true if the `collect` and `take` methods can be run locally (without any Spark + * executors). * * @group basic * @since 1.6.0 @@ -301,12 +292,11 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { def isEmpty: Boolean /** - * Returns true if this Dataset contains one or more sources that continuously - * return data as it arrives. A Dataset that reads data from a streaming source - * must be executed as a `StreamingQuery` using the `start()` method in - * `DataStreamWriter`. Methods that return a single answer, e.g. `count()` or - * `collect()`, will throw an [[org.apache.spark.sql.AnalysisException]] when there is a - * streaming source present. + * Returns true if this Dataset contains one or more sources that continuously return data as it + * arrives. A Dataset that reads data from a streaming source must be executed as a + * `StreamingQuery` using the `start()` method in `DataStreamWriter`. Methods that return a + * single answer, e.g. `count()` or `collect()`, will throw an + * [[org.apache.spark.sql.AnalysisException]] when there is a streaming source present. * * @group streaming * @since 2.0.0 @@ -314,103 +304,106 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { def isStreaming: Boolean /** - * Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to truncate - * the logical plan of this Dataset, which is especially useful in iterative algorithms where the - * plan may grow exponentially. It will be saved to files inside the checkpoint + * Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to + * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms + * where the plan may grow exponentially. It will be saved to files inside the checkpoint * directory set with `SparkContext#setCheckpointDir`. * * @group basic * @since 2.1.0 */ - def checkpoint(): DS[T] = checkpoint(eager = true, reliableCheckpoint = true) + def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true) /** * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the * logical plan of this Dataset, which is especially useful in iterative algorithms where the - * plan may grow exponentially. It will be saved to files inside the checkpoint - * directory set with `SparkContext#setCheckpointDir`. - * - * @param eager Whether to checkpoint this dataframe immediately - * @note When checkpoint is used with eager = false, the final data that is checkpointed after - * the first action may be different from the data that was used during the job due to - * non-determinism of the underlying operation and retries. 
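To make the eager/lazy distinction above concrete, a minimal sketch; the checkpoint directory is a placeholder, and a classic `SparkSession` exposing `sparkContext` is assumed:
{{{
  spark.sparkContext.setCheckpointDir("/tmp/checkpoints")   // placeholder directory

  val base = spark.range(0, 1000000).toDF("id")

  // Eager, reliable checkpoint: materialized immediately, logical plan truncated.
  val snapshot = base.checkpoint()

  // Lazy, local checkpoint: only materialized by the first action that touches it.
  val lazyLocal = base.localCheckpoint(eager = false)
  lazyLocal.count()
}}}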
If checkpoint is used to achieve - * saving a deterministic snapshot of the data, eager = true should be used. Otherwise, - * it is only deterministic after the first execution, after the checkpoint was finalized. + * plan may grow exponentially. It will be saved to files inside the checkpoint directory set + * with `SparkContext#setCheckpointDir`. + * + * @param eager + * Whether to checkpoint this dataframe immediately + * @note + * When checkpoint is used with eager = false, the final data that is checkpointed after the + * first action may be different from the data that was used during the job due to + * non-determinism of the underlying operation and retries. If checkpoint is used to achieve + * saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is + * only deterministic after the first execution, after the checkpoint was finalized. * @group basic * @since 2.1.0 */ - def checkpoint(eager: Boolean): DS[T] = checkpoint(eager = eager, reliableCheckpoint = true) + def checkpoint(eager: Boolean): Dataset[T] = + checkpoint(eager = eager, reliableCheckpoint = true) /** - * Eagerly locally checkpoints a Dataset and return the new Dataset. Checkpointing can be - * used to truncate the logical plan of this Dataset, which is especially useful in iterative + * Eagerly locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used + * to truncate the logical plan of this Dataset, which is especially useful in iterative * algorithms where the plan may grow exponentially. Local checkpoints are written to executor * storage and despite potentially faster they are unreliable and may compromise job completion. * * @group basic * @since 2.3.0 */ - def localCheckpoint(): DS[T] = checkpoint(eager = true, reliableCheckpoint = false) + def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false) /** - * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to truncate - * the logical plan of this Dataset, which is especially useful in iterative algorithms where the - * plan may grow exponentially. Local checkpoints are written to executor storage and despite - * potentially faster they are unreliable and may compromise job completion. + * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to + * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms + * where the plan may grow exponentially. Local checkpoints are written to executor storage and + * despite potentially faster they are unreliable and may compromise job completion. * - * @param eager Whether to checkpoint this dataframe immediately - * @note When checkpoint is used with eager = false, the final data that is checkpointed after - * the first action may be different from the data that was used during the job due to - * non-determinism of the underlying operation and retries. If checkpoint is used to achieve - * saving a deterministic snapshot of the data, eager = true should be used. Otherwise, - * it is only deterministic after the first execution, after the checkpoint was finalized. + * @param eager + * Whether to checkpoint this dataframe immediately + * @note + * When checkpoint is used with eager = false, the final data that is checkpointed after the + * first action may be different from the data that was used during the job due to + * non-determinism of the underlying operation and retries. 
If checkpoint is used to achieve + * saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is + * only deterministic after the first execution, after the checkpoint was finalized. * @group basic * @since 2.3.0 */ - def localCheckpoint(eager: Boolean): DS[T] = checkpoint( - eager = eager, - reliableCheckpoint = false - ) + def localCheckpoint(eager: Boolean): Dataset[T] = + checkpoint(eager = eager, reliableCheckpoint = false) /** * Returns a checkpointed version of this Dataset. * - * @param eager Whether to checkpoint this dataframe immediately - * @param reliableCheckpoint Whether to create a reliable checkpoint saved to files inside the - * checkpoint directory. If false creates a local checkpoint using - * the caching subsystem + * @param eager + * Whether to checkpoint this dataframe immediately + * @param reliableCheckpoint + * Whether to create a reliable checkpoint saved to files inside the checkpoint directory. If + * false creates a local checkpoint using the caching subsystem */ - protected def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): DS[T] + protected def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] /** * Defines an event time watermark for this [[Dataset]]. A watermark tracks a point in time * before which we assume no more late data is going to arrive. * - * Spark will use this watermark for several purposes: - *
<ul> - * <li>To know when a given time window aggregation can be finalized and thus can be emitted
- * when using output modes that do not allow updates.</li> - * <li>To minimize the amount of state that we need to keep for on-going aggregations,
- * `mapGroupsWithState` and `dropDuplicates` operators.</li> - * </ul> - *
    - * The current watermark is computed by looking at the `MAX(eventTime)` seen across - * all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost - * of coordinating this value across partitions, the actual watermark used is only guaranteed - * to be at least `delayThreshold` behind the actual event time. In some cases we may still - * process records that arrive more than `delayThreshold` late. - * - * @param eventTime the name of the column that contains the event time of the row. - * @param delayThreshold the minimum delay to wait to data to arrive late, relative to the latest - * record that has been processed in the form of an interval - * (e.g. "1 minute" or "5 hours"). NOTE: This should not be negative. + * Spark will use this watermark for several purposes:
<ul> <li>To know when a given time window
+ * aggregation can be finalized and thus can be emitted when using output modes that do not
+ * allow updates.</li> <li>To minimize the amount of state that we need to keep for on-going
+ * aggregations, `mapGroupsWithState` and `dropDuplicates` operators.</li> </ul>
    The current + * watermark is computed by looking at the `MAX(eventTime)` seen across all of the partitions in + * the query minus a user specified `delayThreshold`. Due to the cost of coordinating this value + * across partitions, the actual watermark used is only guaranteed to be at least + * `delayThreshold` behind the actual event time. In some cases we may still process records + * that arrive more than `delayThreshold` late. + * + * @param eventTime + * the name of the column that contains the event time of the row. + * @param delayThreshold + * the minimum delay to wait to data to arrive late, relative to the latest record that has + * been processed in the form of an interval (e.g. "1 minute" or "5 hours"). NOTE: This should + * not be negative. * @group streaming * @since 2.1.0 */ // We only accept an existing column name, not a derived column here as a watermark that is // defined on a derived column cannot referenced elsewhere in the plan. - def withWatermark(eventTime: String, delayThreshold: String): DS[T] + def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] - /** + /** * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated, * and all cells will be aligned right. For example: * {{{ @@ -422,7 +415,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * 1984 04 0.450090 0.483521 * }}} * - * @param numRows Number of rows to show + * @param numRows + * Number of rows to show * * @group action * @since 1.6.0 @@ -430,8 +424,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { def show(numRows: Int): Unit = show(numRows, truncate = true) /** - * Displays the top 20 rows of Dataset in a tabular form. Strings more than 20 characters - * will be truncated, and all cells will be aligned right. + * Displays the top 20 rows of Dataset in a tabular form. Strings more than 20 characters will + * be truncated, and all cells will be aligned right. * * @group action * @since 1.6.0 @@ -441,8 +435,9 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { /** * Displays the top 20 rows of Dataset in a tabular form. * - * @param truncate Whether truncate long strings. If true, strings more than 20 characters will - * be truncated and all cells will be aligned right + * @param truncate + * Whether truncate long strings. If true, strings more than 20 characters will be truncated + * and all cells will be aligned right * * @group action * @since 1.6.0 @@ -459,9 +454,11 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * 1983 03 0.410516 0.442194 * 1984 04 0.450090 0.483521 * }}} - * @param numRows Number of rows to show - * @param truncate Whether truncate long strings. If true, strings more than 20 characters will - * be truncated and all cells will be aligned right + * @param numRows + * Number of rows to show + * @param truncate + * Whether truncate long strings. If true, strings more than 20 characters will be truncated + * and all cells will be aligned right * * @group action * @since 1.6.0 @@ -480,9 +477,11 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * 1984 04 0.450090 0.483521 * }}} * - * @param numRows Number of rows to show - * @param truncate If set to more than 0, truncates strings to `truncate` characters and - * all cells will be aligned right. 
+ * @param numRows + * Number of rows to show + * @param truncate + * If set to more than 0, truncates strings to `truncate` characters and all cells will be + * aligned right. * @group action * @since 1.6.0 */ @@ -499,7 +498,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * 1984 04 0.450090 0.483521 * }}} * - * If `vertical` enabled, this command prints output rows vertically (one line per column value)? + * If `vertical` enabled, this command prints output rows vertically (one line per column + * value)? * * {{{ * -RECORD 0------------------- @@ -529,10 +529,13 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * AVG('Adj Close) | 0.483521 * }}} * - * @param numRows Number of rows to show - * @param truncate If set to more than 0, truncates strings to `truncate` characters and - * all cells will be aligned right. - * @param vertical If set to true, prints output rows vertically (one line per column value). + * @param numRows + * Number of rows to show + * @param truncate + * If set to more than 0, truncates strings to `truncate` characters and all cells will be + * aligned right. + * @param vertical + * If set to true, prints output rows vertically (one line per column value). * @group action * @since 2.3.0 */ @@ -549,7 +552,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 1.6.0 */ - def na: DataFrameNaFunctions[DS] + def na: DataFrameNaFunctions /** * Returns a [[DataFrameStatFunctions]] for working statistic functions support. @@ -561,20 +564,21 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 1.6.0 */ - def stat: DataFrameStatFunctions[DS] + def stat: DataFrameStatFunctions /** * Join with another `DataFrame`. * * Behaves as an INNER JOIN and requires a subsequent join predicate. * - * @param right Right side of the join operation. + * @param right + * Right side of the join operation. * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_]): DS[Row] + def join(right: DS[_]): Dataset[Row] - /** + /** * Inner equi-join with another `DataFrame` using the given column. * * Different from other join functions, the join column will only appear once in the output, @@ -585,17 +589,20 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * df1.join(df2, "user_id") * }}} * - * @param right Right side of the join operation. - * @param usingColumn Name of the column to join on. This column must exist on both sides. + * @param right + * Right side of the join operation. + * @param usingColumn + * Name of the column to join on. This column must exist on both sides. * - * @note If you perform a self-join using this function without aliasing the input - * `DataFrame`s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. + * @note + * If you perform a self-join using this function without aliasing the input `DataFrame`s, you + * will NOT be able to reference any columns after the join, since there is no way to + * disambiguate which side of the join you would like to reference. 
* * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], usingColumn: String): DS[Row] = { + def join(right: DS[_], usingColumn: String): Dataset[Row] = { join(right, Seq(usingColumn)) } @@ -603,13 +610,15 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * (Java-specific) Inner equi-join with another `DataFrame` using the given columns. See the * Scala-specific overload for more details. * - * @param right Right side of the join operation. - * @param usingColumns Names of the columns to join on. This columns must exist on both sides. + * @param right + * Right side of the join operation. + * @param usingColumns + * Names of the columns to join on. This columns must exist on both sides. * * @group untypedrel * @since 3.4.0 */ - def join(right: DS[_], usingColumns: Array[String]): DS[Row] = { + def join(right: DS[_], usingColumns: Array[String]): Dataset[Row] = { join(right, usingColumns.toImmutableArraySeq) } @@ -624,43 +633,50 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * df1.join(df2, Seq("user_id", "user_name")) * }}} * - * @param right Right side of the join operation. - * @param usingColumns Names of the columns to join on. This columns must exist on both sides. + * @param right + * Right side of the join operation. + * @param usingColumns + * Names of the columns to join on. This columns must exist on both sides. * - * @note If you perform a self-join using this function without aliasing the input - * `DataFrame`s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. + * @note + * If you perform a self-join using this function without aliasing the input `DataFrame`s, you + * will NOT be able to reference any columns after the join, since there is no way to + * disambiguate which side of the join you would like to reference. * * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], usingColumns: Seq[String]): DS[Row] = { + def join(right: DS[_], usingColumns: Seq[String]): Dataset[Row] = { join(right, usingColumns, "inner") } /** - * Equi-join with another `DataFrame` using the given column. A cross join with a predicate - * is specified as an inner join. If you would explicitly like to perform a cross join use the + * Equi-join with another `DataFrame` using the given column. A cross join with a predicate is + * specified as an inner join. If you would explicitly like to perform a cross join use the * `crossJoin` method. * * Different from other join functions, the join column will only appear once in the output, * i.e. similar to SQL's `JOIN USING` syntax. * - * @param right Right side of the join operation. - * @param usingColumn Name of the column to join on. This column must exist on both sides. - * @param joinType Type of join to perform. Default `inner`. Must be one of: - * `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`, - * `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`, - * `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`. - * - * @note If you perform a self-join using this function without aliasing the input - * `DataFrame`s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. + * @param right + * Right side of the join operation. + * @param usingColumn + * Name of the column to join on. This column must exist on both sides. 
+ * @param joinType + * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`, + * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`, + * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, + * `left_anti`. + * + * @note + * If you perform a self-join using this function without aliasing the input `DataFrame`s, you + * will NOT be able to reference any columns after the join, since there is no way to + * disambiguate which side of the join you would like to reference. * * @group untypedrel * @since 3.4.0 */ - def join(right: DS[_], usingColumn: String, joinType: String): DS[Row] = { + def join(right: DS[_], usingColumn: String, joinType: String): Dataset[Row] = { join(right, Seq(usingColumn), joinType) } @@ -668,17 +684,20 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * (Java-specific) Equi-join with another `DataFrame` using the given columns. See the * Scala-specific overload for more details. * - * @param right Right side of the join operation. - * @param usingColumns Names of the columns to join on. This columns must exist on both sides. - * @param joinType Type of join to perform. Default `inner`. Must be one of: - * `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`, - * `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`, - * `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`. + * @param right + * Right side of the join operation. + * @param usingColumns + * Names of the columns to join on. This columns must exist on both sides. + * @param joinType + * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`, + * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`, + * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, + * `left_anti`. * * @group untypedrel * @since 3.4.0 */ - def join(right: DS[_], usingColumns: Array[String], joinType: String): DS[Row] = { + def join(right: DS[_], usingColumns: Array[String], joinType: String): Dataset[Row] = { join(right, usingColumns.toImmutableArraySeq, joinType) } @@ -690,21 +709,25 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * Different from other join functions, the join columns will only appear once in the output, * i.e. similar to SQL's `JOIN USING` syntax. * - * @param right Right side of the join operation. - * @param usingColumns Names of the columns to join on. This columns must exist on both sides. - * @param joinType Type of join to perform. Default `inner`. Must be one of: - * `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`, - * `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`, - * `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`. - * - * @note If you perform a self-join using this function without aliasing the input - * `DataFrame`s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. + * @param right + * Right side of the join operation. + * @param usingColumns + * Names of the columns to join on. This columns must exist on both sides. + * @param joinType + * Type of join to perform. Default `inner`. 
Must be one of: `inner`, `cross`, `outer`, + * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`, + * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, + * `left_anti`. + * + * @note + * If you perform a self-join using this function without aliasing the input `DataFrame`s, you + * will NOT be able to reference any columns after the join, since there is no way to + * disambiguate which side of the join you would like to reference. * * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], usingColumns: Seq[String], joinType: String): DS[Row] + def join(right: DS[_], usingColumns: Seq[String], joinType: String): Dataset[Row] /** * Inner join with another `DataFrame`, using the given join expression. @@ -718,12 +741,12 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], joinExprs: Column): DS[Row] = + def join(right: DS[_], joinExprs: Column): Dataset[Row] = join(right, joinExprs, "inner") /** - * Join with another `DataFrame`, using the given join expression. The following performs - * a full outer join between `df1` and `df2`. + * Join with another `DataFrame`, using the given join expression. The following performs a full + * outer join between `df1` and `df2`. * * {{{ * // Scala: @@ -735,64 +758,73 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * df1.join(df2, col("df1Key").equalTo(col("df2Key")), "outer"); * }}} * - * @param right Right side of the join. - * @param joinExprs Join expression. - * @param joinType Type of join to perform. Default `inner`. Must be one of: - * `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`, - * `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`, - * `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`. + * @param right + * Right side of the join. + * @param joinExprs + * Join expression. + * @param joinType + * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`, + * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`, + * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, + * `left_anti`. * * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], joinExprs: Column, joinType: String): DS[Row] + def join(right: DS[_], joinExprs: Column, joinType: String): Dataset[Row] /** * Explicit cartesian join with another `DataFrame`. * - * @param right Right side of the join operation. - * @note Cartesian joins are very expensive without an extra filter that can be pushed down. + * @param right + * Right side of the join operation. + * @note + * Cartesian joins are very expensive without an extra filter that can be pushed down. * @group untypedrel * @since 2.1.0 */ - def crossJoin(right: DS[_]): DS[Row] + def crossJoin(right: DS[_]): Dataset[Row] /** - * Joins this Dataset returning a `Tuple2` for each pair where `condition` evaluates to - * true. + * Joins this Dataset returning a `Tuple2` for each pair where `condition` evaluates to true. * - * This is similar to the relation `join` function with one important difference in the - * result schema. Since `joinWith` preserves objects present on either side of the join, the - * result schema is similarly nested into a tuple under the column names `_1` and `_2`. + * This is similar to the relation `join` function with one important difference in the result + * schema. 
Since `joinWith` preserves objects present on either side of the join, the result + * schema is similarly nested into a tuple under the column names `_1` and `_2`. * * This type of join can be useful both for preserving type-safety with the original object - * types as well as working with relational data where either side of the join has column - * names in common. - * - * @param other Right side of the join. - * @param condition Join expression. - * @param joinType Type of join to perform. Default `inner`. Must be one of: - * `inner`, `cross`, `outer`, `full`, `fullouter`,`full_outer`, `left`, - * `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`. + * types as well as working with relational data where either side of the join has column names + * in common. + * + * @param other + * Right side of the join. + * @param condition + * Join expression. + * @param joinType + * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`, + * `full`, `fullouter`,`full_outer`, `left`, `leftouter`, `left_outer`, `right`, `rightouter`, + * `right_outer`. * @group typedrel * @since 1.6.0 */ - def joinWith[U](other: DS[U], condition: Column, joinType: String): DS[(T, U)] + def joinWith[U](other: DS[U], condition: Column, joinType: String): Dataset[(T, U)] /** - * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair - * where `condition` evaluates to true. + * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair where + * `condition` evaluates to true. * - * @param other Right side of the join. - * @param condition Join expression. + * @param other + * Right side of the join. + * @param condition + * Join expression. * @group typedrel * @since 1.6.0 */ - def joinWith[U](other: DS[U], condition: Column): DS[(T, U)] = { + def joinWith[U](other: DS[U], condition: Column): Dataset[(T, U)] = { joinWith(other, condition, "inner") } - protected def sortInternal(global: Boolean, sortExprs: Seq[Column]): DS[T] + protected def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] /** * Returns a new Dataset with each partition sorted by the given expressions. 
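Sketch (not part of the diff): how the untyped `join` overloads above differ from the typed `joinWith`. The case classes, data, and SparkSession setup are illustrative assumptions; only the `join`/`joinWith` calls mirror the API documented in these hunks.

```scala
// Minimal sketch: untyped join vs. typed joinWith (names and data are illustrative only).
import org.apache.spark.sql.SparkSession

case class User(userId: Long, name: String)
case class Order(orderId: Long, userId: Long, amount: Double)

object JoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("join-sketch").getOrCreate()
    import spark.implicits._

    val users  = Seq(User(1L, "alice"), User(2L, "bob")).toDS()
    val orders = Seq(Order(10L, 1L, 42.0), Order(11L, 1L, 7.5)).toDS()

    // Equi-join on a shared column: the column appears once and the result is a DataFrame (Dataset[Row]).
    val joined = users.join(orders, Seq("userId"), "inner")
    joined.show()

    // joinWith keeps both sides as typed objects in a Tuple2: Dataset[(User, Order)].
    val typed = users.joinWith(orders, users("userId") === orders("userId"), "inner")
    typed.show(truncate = false)

    spark.stop()
  }
}
```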
@@ -803,8 +835,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def sortWithinPartitions(sortCol: String, sortCols: String*): DS[T] = { - sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*) + def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = { + sortWithinPartitions((sortCol +: sortCols).map(Column(_)): _*) } /** @@ -816,7 +848,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def sortWithinPartitions(sortExprs: Column*): DS[T] = { + def sortWithinPartitions(sortExprs: Column*): Dataset[T] = { sortInternal(global = false, sortExprs) } @@ -833,8 +865,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def sort(sortCol: String, sortCols: String*): DS[T] = { - sort((sortCol +: sortCols).map(Column(_)) : _*) + def sort(sortCol: String, sortCols: String*): Dataset[T] = { + sort((sortCol +: sortCols).map(Column(_)): _*) } /** @@ -847,33 +879,33 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def sort(sortExprs: Column*): DS[T] = { + def sort(sortExprs: Column*): Dataset[T] = { sortInternal(global = true, sortExprs) } /** - * Returns a new Dataset sorted by the given expressions. - * This is an alias of the `sort` function. + * Returns a new Dataset sorted by the given expressions. This is an alias of the `sort` + * function. * * @group typedrel * @since 2.0.0 */ @scala.annotation.varargs - def orderBy(sortCol: String, sortCols: String*): DS[T] = sort(sortCol, sortCols : _*) + def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols: _*) /** - * Returns a new Dataset sorted by the given expressions. - * This is an alias of the `sort` function. + * Returns a new Dataset sorted by the given expressions. This is an alias of the `sort` + * function. * * @group typedrel * @since 2.0.0 */ @scala.annotation.varargs - def orderBy(sortExprs: Column*): DS[T] = sort(sortExprs : _*) + def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs: _*) /** - * Specifies some hint on the current Dataset. As an example, the following code specifies - * that one of the plan can be broadcasted: + * Specifies some hint on the current Dataset. As an example, the following code specifies that + * one of the plan can be broadcasted: * * {{{ * df1.join(df2.hint("broadcast")) @@ -886,19 +918,22 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * df1.hint("rebalance", 10) * }}} * - * @param name the name of the hint - * @param parameters the parameters of the hint, all the parameters should be a `Column` or - * `Expression` or `Symbol` or could be converted into a `Literal` + * @param name + * the name of the hint + * @param parameters + * the parameters of the hint, all the parameters should be a `Column` or `Expression` or + * `Symbol` or could be converted into a `Literal` * @group basic * @since 2.2.0 */ @scala.annotation.varargs - def hint(name: String, parameters: Any*): DS[T] + def hint(name: String, parameters: Any*): Dataset[T] /** * Selects column based on the column name and returns it as a [[org.apache.spark.sql.Column]]. * - * @note The column name can also reference to a nested column like `a.b`. + * @note + * The column name can also reference to a nested column like `a.b`. 
* @group untypedrel * @since 2.0.0 */ @@ -907,7 +942,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { /** * Selects column based on the column name and returns it as a [[org.apache.spark.sql.Column]]. * - * @note The column name can also reference to a nested column like `a.b`. + * @note + * The column name can also reference to a nested column like `a.b`. * @group untypedrel * @since 2.0.0 */ @@ -940,7 +976,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def as(alias: String): DS[T] + def as(alias: String): Dataset[T] /** * (Scala-specific) Returns a new Dataset with an alias set. @@ -948,7 +984,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def as(alias: Symbol): DS[T] = as(alias.name) + def as(alias: Symbol): Dataset[T] = as(alias.name) /** * Returns a new Dataset with an alias set. Same as `as`. @@ -956,7 +992,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def alias(alias: String): DS[T] = as(alias) + def alias(alias: String): Dataset[T] = as(alias) /** * (Scala-specific) Returns a new Dataset with an alias set. Same as `as`. @@ -964,7 +1000,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def alias(alias: Symbol): DS[T] = as(alias) + def alias(alias: Symbol): Dataset[T] = as(alias) /** * Selects a set of column based expressions. @@ -976,11 +1012,11 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def select(cols: Column*): DS[Row] + def select(cols: Column*): Dataset[Row] /** - * Selects a set of columns. This is a variant of `select` that can only select - * existing columns using column names (i.e. cannot construct expressions). + * Selects a set of columns. This is a variant of `select` that can only select existing columns + * using column names (i.e. cannot construct expressions). * * {{{ * // The following two are equivalent: @@ -992,11 +1028,10 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def select(col: String, cols: String*): DS[Row] = select((col +: cols).map(Column(_)): _*) + def select(col: String, cols: String*): Dataset[Row] = select((col +: cols).map(Column(_)): _*) /** - * Selects a set of SQL expressions. This is a variant of `select` that accepts - * SQL expressions. + * Selects a set of SQL expressions. This is a variant of `select` that accepts SQL expressions. * * {{{ * // The following are equivalent: @@ -1008,7 +1043,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def selectExpr(exprs: String*): DS[Row] = select(exprs.map(functions.expr): _*) + def selectExpr(exprs: String*): Dataset[Row] = select(exprs.map(functions.expr): _*) /** * Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expression for @@ -1022,14 +1057,14 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def select[U1](c1: TypedColumn[T, U1]): DS[U1] + def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] /** * Internal helper function for building typed selects that return tuples. 
For simplicity and - * code reuse, we do this without the help of the type system and then use helper functions - * that cast appropriately for the user facing interface. + * code reuse, we do this without the help of the type system and then use helper functions that + * cast appropriately for the user facing interface. */ - protected def selectUntyped(columns: TypedColumn[_, _]*): DS[_] + protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] /** * Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expressions for @@ -1038,8 +1073,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): DS[(U1, U2)] = - selectUntyped(c1, c2).asInstanceOf[DS[(U1, U2)]] + def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = + selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] /** * Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expressions for @@ -1051,8 +1086,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { def select[U1, U2, U3]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], - c3: TypedColumn[T, U3]): DS[(U1, U2, U3)] = - selectUntyped(c1, c2, c3).asInstanceOf[DS[(U1, U2, U3)]] + c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] = + selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] /** * Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expressions for @@ -1065,8 +1100,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], c3: TypedColumn[T, U3], - c4: TypedColumn[T, U4]): DS[(U1, U2, U3, U4)] = - selectUntyped(c1, c2, c3, c4).asInstanceOf[DS[(U1, U2, U3, U4)]] + c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] = + selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] /** * Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expressions for @@ -1080,8 +1115,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { c2: TypedColumn[T, U2], c3: TypedColumn[T, U3], c4: TypedColumn[T, U4], - c5: TypedColumn[T, U5]): DS[(U1, U2, U3, U4, U5)] = - selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[DS[(U1, U2, U3, U4, U5)]] + c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = + selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] /** * Filters rows using the given condition. @@ -1094,7 +1129,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def filter(condition: Column): DS[T] + def filter(condition: Column): Dataset[T] /** * Filters rows using the given SQL expression. @@ -1105,26 +1140,26 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def filter(conditionExpr: String): DS[T] = + def filter(conditionExpr: String): Dataset[T] = filter(functions.expr(conditionExpr)) /** - * (Scala-specific) - * Returns a new Dataset that only contains elements where `func` returns `true`. + * (Scala-specific) Returns a new Dataset that only contains elements where `func` returns + * `true`. * * @group typedrel * @since 1.6.0 */ - def filter(func: T => Boolean): DS[T] + def filter(func: T => Boolean): Dataset[T] /** - * (Java-specific) - * Returns a new Dataset that only contains elements where `func` returns `true`. 
+ * (Java-specific) Returns a new Dataset that only contains elements where `func` returns + * `true`. * * @group typedrel * @since 1.6.0 */ - def filter(func: FilterFunction[T]): DS[T] + def filter(func: FilterFunction[T]): Dataset[T] /** * Filters rows using the given condition. This is an alias for `filter`. @@ -1137,7 +1172,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def where(condition: Column): DS[T] = filter(condition) + def where(condition: Column): Dataset[T] = filter(condition) /** * Filters rows using the given SQL expression. @@ -1148,7 +1183,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def where(conditionExpr: String): DS[T] = filter(conditionExpr) + def where(conditionExpr: String): Dataset[T] = filter(conditionExpr) /** * Groups the Dataset using the specified columns, so we can run aggregation on them. See @@ -1169,14 +1204,14 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def groupBy(cols: Column*): RGD + def groupBy(cols: Column*): RelationalGroupedDataset /** - * Groups the Dataset using the specified columns, so that we can run aggregation on them. - * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * Groups the Dataset using the specified columns, so that we can run aggregation on them. See + * [[RelationalGroupedDataset]] for all the available aggregate functions. * - * This is a variant of groupBy that can only group by existing columns using column names - * (i.e. cannot construct expressions). + * This is a variant of groupBy that can only group by existing columns using column names (i.e. + * cannot construct expressions). * * {{{ * // Compute the average for all numeric columns grouped by department. @@ -1193,12 +1228,13 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def groupBy(col1: String, cols: String*): RGD = groupBy((col1 +: cols).map(col): _*) + def groupBy(col1: String, cols: String*): RelationalGroupedDataset = groupBy( + (col1 +: cols).map(col): _*) /** - * Create a multi-dimensional rollup for the current Dataset using the specified columns, - * so we can run aggregation on them. - * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we + * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate + * functions. * * {{{ * // Compute the average for all numeric columns rolled up by department and group. @@ -1215,15 +1251,15 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def rollup(cols: Column*): RGD + def rollup(cols: Column*): RelationalGroupedDataset /** - * Create a multi-dimensional rollup for the current Dataset using the specified columns, - * so we can run aggregation on them. - * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we + * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate + * functions. * - * This is a variant of rollup that can only group by existing columns using column names - * (i.e. cannot construct expressions). 
+ * This is a variant of rollup that can only group by existing columns using column names (i.e. + * cannot construct expressions). * * {{{ * // Compute the average for all numeric columns rolled up by department and group. @@ -1240,12 +1276,13 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def rollup(col1: String, cols: String*): RGD = rollup((col1 +: cols).map(col): _*) + def rollup(col1: String, cols: String*): RelationalGroupedDataset = rollup( + (col1 +: cols).map(col): _*) /** - * Create a multi-dimensional cube for the current Dataset using the specified columns, - * so we can run aggregation on them. - * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * Create a multi-dimensional cube for the current Dataset using the specified columns, so we + * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate + * functions. * * {{{ * // Compute the average for all numeric columns cubed by department and group. @@ -1262,15 +1299,15 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def cube(cols: Column*): RGD + def cube(cols: Column*): RelationalGroupedDataset /** - * Create a multi-dimensional cube for the current Dataset using the specified columns, - * so we can run aggregation on them. - * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * Create a multi-dimensional cube for the current Dataset using the specified columns, so we + * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate + * functions. * - * This is a variant of cube that can only group by existing columns using column names - * (i.e. cannot construct expressions). + * This is a variant of cube that can only group by existing columns using column names (i.e. + * cannot construct expressions). * * {{{ * // Compute the average for all numeric columns cubed by department and group. @@ -1287,12 +1324,13 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def cube(col1: String, cols: String*): RGD = cube((col1 +: cols).map(col): _*) + def cube(col1: String, cols: String*): RelationalGroupedDataset = cube( + (col1 +: cols).map(col): _*) /** - * Create multi-dimensional aggregation for the current Dataset using the specified grouping sets, - * so we can run aggregation on them. - * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * Create multi-dimensional aggregation for the current Dataset using the specified grouping + * sets, so we can run aggregation on them. See [[RelationalGroupedDataset]] for all the + * available aggregate functions. * * {{{ * // Compute the average for all numeric columns group by specific grouping sets. @@ -1309,7 +1347,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 4.0.0 */ @scala.annotation.varargs - def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RGD + def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RelationalGroupedDataset /** * (Scala-specific) Aggregates on the entire Dataset without groups. 
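Sketch (not part of the diff): `groupBy`, `rollup`, and `cube` applied to the same columns, to show the subtotal rows each one adds. The data and column names are illustrative assumptions; the calls mirror the API documented above.

```scala
// Minimal sketch: groupBy vs. rollup vs. cube on the same columns (data is illustrative only).
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.avg

object GroupingSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("grouping-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(
      ("sales", "junior", 100),
      ("sales", "senior", 200),
      ("eng",   "senior", 300)
    ).toDF("department", "group", "salary")

    // One row per (department, group) combination.
    df.groupBy("department", "group").agg(avg("salary")).show()

    // Adds hierarchical subtotals: (department, group), (department, null), (null, null).
    df.rollup("department", "group").agg(avg("salary")).show()

    // Adds subtotals for every combination, including (null, group).
    df.cube("department", "group").agg(avg("salary")).show()

    spark.stop()
  }
}
```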
@@ -1322,7 +1360,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DS[Row] = { + def agg(aggExpr: (String, String), aggExprs: (String, String)*): Dataset[Row] = { groupBy().agg(aggExpr, aggExprs: _*) } @@ -1337,7 +1375,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def agg(exprs: Map[String, String]): DS[Row] = groupBy().agg(exprs) + def agg(exprs: Map[String, String]): Dataset[Row] = groupBy().agg(exprs) /** * (Java-specific) Aggregates on the entire Dataset without groups. @@ -1350,7 +1388,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def agg(exprs: util.Map[String, String]): DS[Row] = groupBy().agg(exprs) + def agg(exprs: util.Map[String, String]): Dataset[Row] = groupBy().agg(exprs) /** * Aggregates on the entire Dataset without groups. @@ -1364,12 +1402,11 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DS[Row] = groupBy().agg(expr, exprs: _*) + def agg(expr: Column, exprs: Column*): Dataset[Row] = groupBy().agg(expr, exprs: _*) /** - * (Scala-specific) - * Reduces the elements of this Dataset using the specified binary function. The given `func` - * must be commutative and associative or the result may be non-deterministic. + * (Scala-specific) Reduces the elements of this Dataset using the specified binary function. + * The given `func` must be commutative and associative or the result may be non-deterministic. * * @group action * @since 1.6.0 @@ -1377,24 +1414,45 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { def reduce(func: (T, T) => T): T /** - * (Java-specific) - * Reduces the elements of this Dataset using the specified binary function. The given `func` - * must be commutative and associative or the result may be non-deterministic. + * (Java-specific) Reduces the elements of this Dataset using the specified binary function. The + * given `func` must be commutative and associative or the result may be non-deterministic. * * @group action * @since 1.6.0 */ - def reduce(func: ReduceFunction[T]): T = reduce(func.call) + def reduce(func: ReduceFunction[T]): T = reduce(ToScalaUDF(func)) + + /** + * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given + * key `func`. + * + * @group typedrel + * @since 2.0.0 + */ + def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] + + /** + * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given + * key `func`. + * + * @group typedrel + * @since 2.0.0 + */ + def groupByKey[K]( + func: MapFunction[T, K], + encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = { + groupByKey(ToScalaUDF(func))(encoder) + } /** - * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set. - * This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, + * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns + * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, * which cannot be reversed. 
* - * This function is useful to massage a DataFrame into a format where some - * columns are identifier columns ("ids"), while all other columns ("values") - * are "unpivoted" to the rows, leaving just two non-id columns, named as given - * by `variableColumnName` and `valueColumnName`. + * This function is useful to massage a DataFrame into a format where some columns are + * identifier columns ("ids"), while all other columns ("values") are "unpivoted" to the rows, + * leaving just two non-id columns, named as given by `variableColumnName` and + * `valueColumnName`. * * {{{ * val df = Seq((1, 11, 12L), (2, 21, 22L)).toDF("id", "int", "long") @@ -1424,18 +1482,22 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * // |-- value: long (nullable = true) * }}} * - * When no "id" columns are given, the unpivoted DataFrame consists of only the - * "variable" and "value" columns. + * When no "id" columns are given, the unpivoted DataFrame consists of only the "variable" and + * "value" columns. * * All "value" columns must share a least common data type. Unless they are the same data type, - * all "value" columns are cast to the nearest common data type. For instance, - * types `IntegerType` and `LongType` are cast to `LongType`, while `IntegerType` and `StringType` - * do not have a common data type and `unpivot` fails with an `AnalysisException`. - * - * @param ids Id columns - * @param values Value columns to unpivot - * @param variableColumnName Name of the variable column - * @param valueColumnName Name of the value column + * all "value" columns are cast to the nearest common data type. For instance, types + * `IntegerType` and `LongType` are cast to `LongType`, while `IntegerType` and `StringType` do + * not have a common data type and `unpivot` fails with an `AnalysisException`. + * + * @param ids + * Id columns + * @param values + * Value columns to unpivot + * @param variableColumnName + * Name of the variable column + * @param valueColumnName + * Name of the value column * @group untypedrel * @since 3.4.0 */ @@ -1443,21 +1505,25 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { ids: Array[Column], values: Array[Column], variableColumnName: String, - valueColumnName: String): DS[Row] + valueColumnName: String): Dataset[Row] /** - * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set. - * This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, + * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns + * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, * which cannot be reversed. * - * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)` + * @see + * `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)` * - * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)` - * where `values` is set to all non-id columns that exist in the DataFrame. + * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)` where `values` + * is set to all non-id columns that exist in the DataFrame. 
* - * @param ids Id columns - * @param variableColumnName Name of the variable column - * @param valueColumnName Name of the value column + * @param ids + * Id columns + * @param variableColumnName + * Name of the variable column + * @param valueColumnName + * Name of the value column * * @group untypedrel * @since 3.4.0 @@ -1465,18 +1531,23 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { def unpivot( ids: Array[Column], variableColumnName: String, - valueColumnName: String): DS[Row] + valueColumnName: String): Dataset[Row] /** - * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set. - * This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, + * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns + * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, * which cannot be reversed. This is an alias for `unpivot`. * - * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)` - * @param ids Id columns - * @param values Value columns to unpivot - * @param variableColumnName Name of the variable column - * @param valueColumnName Name of the value column + * @see + * `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)` + * @param ids + * Id columns + * @param values + * Value columns to unpivot + * @param variableColumnName + * Name of the variable column + * @param valueColumnName + * Name of the value column * @group untypedrel * @since 3.4.0 */ @@ -1484,53 +1555,135 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { ids: Array[Column], values: Array[Column], variableColumnName: String, - valueColumnName: String): DS[Row] = + valueColumnName: String): Dataset[Row] = unpivot(ids, values, variableColumnName, valueColumnName) /** - * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set. - * This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, + * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns + * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, * which cannot be reversed. This is an alias for `unpivot`. * - * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)` - * - * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)` - * where `values` is set to all non-id columns that exist in the DataFrame. - * @param ids Id columns - * @param variableColumnName Name of the variable column - * @param valueColumnName Name of the value column + * @see + * `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)` + * + * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)` where `values` + * is set to all non-id columns that exist in the DataFrame. + * @param ids + * Id columns + * @param variableColumnName + * Name of the variable column + * @param valueColumnName + * Name of the value column * @group untypedrel * @since 3.4.0 */ def melt( ids: Array[Column], variableColumnName: String, - valueColumnName: String): DS[Row] = + valueColumnName: String): Dataset[Row] = unpivot(ids, variableColumnName, valueColumnName) - /** - * Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset - * that returns the same result as the input, with the following guarantees: - *
<ul> - * <li>It will compute the defined aggregates (metrics) on all the data that is flowing through
- * the Dataset at that point.</li> - * <li>It will report the value of the defined aggregate columns as soon as we reach a completion
- * point. A completion point is either the end of a query (batch mode) or the end of a streaming
- * epoch. The value of the aggregates only reflects the data processed since the previous
- * completion point.</li> - * </ul>
    - * Please note that continuous execution is currently not supported. - * - * The metrics columns must either contain a literal (e.g. lit(42)), or should contain one or - * more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that - * contain references to the input Dataset's columns must always be wrapped in an aggregate - * function. - * - * @group typedrel - * @since 3.0.0 - */ + /** + * Transposes a DataFrame such that the values in the specified index column become the new + * columns of the DataFrame. + * + * Please note: + * - All columns except the index column must share a least common data type. Unless they are + * the same data type, all columns are cast to the nearest common data type. + * - The name of the column into which the original column names are transposed defaults to + * "key". + * - null values in the index column are excluded from the column names for the transposed + * table, which are ordered in ascending order. + * + * {{{ + * val df = Seq(("A", 1, 2), ("B", 3, 4)).toDF("id", "val1", "val2") + * df.show() + * // output: + * // +---+----+----+ + * // | id|val1|val2| + * // +---+----+----+ + * // | A| 1| 2| + * // | B| 3| 4| + * // +---+----+----+ + * + * df.transpose($"id").show() + * // output: + * // +----+---+---+ + * // | key| A| B| + * // +----+---+---+ + * // |val1| 1| 3| + * // |val2| 2| 4| + * // +----+---+---+ + * // schema: + * // root + * // |-- key: string (nullable = false) + * // |-- A: integer (nullable = true) + * // |-- B: integer (nullable = true) + * + * df.transpose().show() + * // output: + * // +----+---+---+ + * // | key| A| B| + * // +----+---+---+ + * // |val1| 1| 3| + * // |val2| 2| 4| + * // +----+---+---+ + * // schema: + * // root + * // |-- key: string (nullable = false) + * // |-- A: integer (nullable = true) + * // |-- B: integer (nullable = true) + * }}} + * + * @param indexColumn + * The single column that will be treated as the index for the transpose operation. This + * column will be used to pivot the data, transforming the DataFrame such that the values of + * the indexColumn become the new columns in the transposed DataFrame. + * + * @group untypedrel + * @since 4.0.0 + */ + def transpose(indexColumn: Column): Dataset[Row] + + /** + * Transposes a DataFrame, switching rows to columns. This function transforms the DataFrame + * such that the values in the first column become the new columns of the DataFrame. + * + * This is equivalent to calling `Dataset#transpose(Column)` where `indexColumn` is set to the + * first column. + * + * Please note: + * - All columns except the index column must share a least common data type. Unless they are + * the same data type, all columns are cast to the nearest common data type. + * - The name of the column into which the original column names are transposed defaults to + * "key". + * - Non-"key" column names for the transposed table are ordered in ascending order. + * + * @group untypedrel + * @since 4.0.0 + */ + def transpose(): Dataset[Row] + + /** + * Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset + * that returns the same result as the input, with the following guarantees:
<ul> <li>It will
+ * compute the defined aggregates (metrics) on all the data that is flowing through the Dataset
+ * at that point.</li> <li>It will report the value of the defined aggregate columns as soon as
+ * we reach a completion point. A completion point is either the end of a query (batch mode) or
+ * the end of a streaming epoch. The value of the aggregates only reflects the data processed
+ * since the previous completion point.</li> </ul>
    Please note that continuous execution is + * currently not supported. + * + * The metrics columns must either contain a literal (e.g. lit(42)), or should contain one or + * more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that + * contain references to the input Dataset's columns must always be wrapped in an aggregate + * function. + * + * @group typedrel + * @since 3.0.0 + */ @scala.annotation.varargs - def observe(name: String, expr: Column, exprs: Column*): DS[T] + def observe(name: String, expr: Column, exprs: Column*): Dataset[T] /** * Observe (named) metrics through an `org.apache.spark.sql.Observation` instance. This method @@ -1546,23 +1699,24 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * val metrics = observation.get * }}} * - * @throws IllegalArgumentException If this is a streaming Dataset (this.isStreaming == true) + * @throws IllegalArgumentException + * If this is a streaming Dataset (this.isStreaming == true) * * @group typedrel * @since 3.3.0 */ @scala.annotation.varargs - def observe(observation: Observation, expr: Column, exprs: Column*): DS[T] + def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] /** - * Returns a new Dataset by taking the first `n` rows. The difference between this function - * and `head` is that `head` is an action and returns an array (by triggering query execution) - * while `limit` returns a new Dataset. + * Returns a new Dataset by taking the first `n` rows. The difference between this function and + * `head` is that `head` is an action and returns an array (by triggering query execution) while + * `limit` returns a new Dataset. * * @group typedrel * @since 2.0.0 */ - def limit(n: Int): DS[T] + def limit(n: Int): Dataset[T] /** * Returns a new Dataset by skipping the first `n` rows. @@ -1570,7 +1724,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 3.4.0 */ - def offset(n: Int): DS[T] + def offset(n: Int): Dataset[T] /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. @@ -1594,19 +1748,19 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * // +----+----+----+ * }}} * - * Notice that the column positions in the schema aren't necessarily matched with the - * fields in the strongly typed objects in a Dataset. This function resolves columns - * by their positions in the schema, not the fields in the strongly typed objects. Use - * [[unionByName]] to resolve columns by field name in the typed objects. + * Notice that the column positions in the schema aren't necessarily matched with the fields in + * the strongly typed objects in a Dataset. This function resolves columns by their positions in + * the schema, not the fields in the strongly typed objects. Use [[unionByName]] to resolve + * columns by field name in the typed objects. * * @group typedrel * @since 2.0.0 */ - def union(other: DS[T]): DS[T] + def union(other: DS[T]): Dataset[T] /** - * Returns a new Dataset containing union of rows in this Dataset and another Dataset. - * This is an alias for `union`. + * Returns a new Dataset containing union of rows in this Dataset and another Dataset. This is + * an alias for `union`. * * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does * deduplication of elements), use this function followed by a [[distinct]]. 
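Sketch (not part of the diff): the `limit`/`offset` docs above stress that both return a new, lazy Dataset, while `head` is an action that triggers execution. A minimal paging example under assumed local data; the `spark.range` setup and column name are illustrative assumptions.

```scala
// Minimal sketch: limit/offset build a new (lazy) Dataset, while head is an action (data is illustrative).
import org.apache.spark.sql.SparkSession

object LimitOffsetSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("limit-offset-sketch").getOrCreate()

    val df = spark.range(1, 101).toDF("id").orderBy("id")

    // Lazy: nothing runs until an action such as show() or collect() is called.
    val secondPage = df.offset(10).limit(10)
    secondPage.show()

    // Eager: head(n) triggers query execution and returns an Array[Row].
    val firstRows = df.head(5)
    firstRows.foreach(println)

    spark.stop()
  }
}
```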
@@ -1616,7 +1770,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def unionAll(other: DS[T]): DS[T] = union(other) + def unionAll(other: DS[T]): Dataset[T] = union(other) /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. @@ -1624,8 +1778,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * This is different from both `UNION ALL` and `UNION DISTINCT` in SQL. To do a SQL-style set * union (that does deduplication of elements), use this function followed by a [[distinct]]. * - * The difference between this function and [[union]] is that this function - * resolves columns by name (not by position): + * The difference between this function and [[union]] is that this function resolves columns by + * name (not by position): * * {{{ * val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2") @@ -1647,18 +1801,17 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.3.0 */ - def unionByName(other: DS[T]): DS[T] = unionByName(other, allowMissingColumns = false) + def unionByName(other: DS[T]): Dataset[T] = unionByName(other, allowMissingColumns = false) /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. * - * The difference between this function and [[union]] is that this function - * resolves columns by name (not by position). + * The difference between this function and [[union]] is that this function resolves columns by + * name (not by position). * - * When the parameter `allowMissingColumns` is `true`, the set of column names - * in this and other `Dataset` can differ; missing columns will be filled with null. - * Further, the missing columns of this `Dataset` will be added at the end - * in the schema of the union result: + * When the parameter `allowMissingColumns` is `true`, the set of column names in this and other + * `Dataset` can differ; missing columns will be filled with null. Further, the missing columns + * of this `Dataset` will be added at the end in the schema of the union result: * * {{{ * val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2") @@ -1692,150 +1845,169 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 3.1.0 */ - def unionByName(other: DS[T], allowMissingColumns: Boolean): DS[T] + def unionByName(other: DS[T], allowMissingColumns: Boolean): Dataset[T] /** - * Returns a new Dataset containing rows only in both this Dataset and another Dataset. - * This is equivalent to `INTERSECT` in SQL. + * Returns a new Dataset containing rows only in both this Dataset and another Dataset. This is + * equivalent to `INTERSECT` in SQL. * - * @note Equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. + * @note + * Equality checking is performed directly on the encoded representation of the data and thus + * is not affected by a custom `equals` function defined on `T`. * @group typedrel * @since 1.6.0 */ - def intersect(other: DS[T]): DS[T] + def intersect(other: DS[T]): Dataset[T] /** * Returns a new Dataset containing rows only in both this Dataset and another Dataset while - * preserving the duplicates. - * This is equivalent to `INTERSECT ALL` in SQL. + * preserving the duplicates. This is equivalent to `INTERSECT ALL` in SQL. 
* - * @note Equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. Also as standard - * in SQL, this function resolves columns by position (not by name). + * @note + * Equality checking is performed directly on the encoded representation of the data and thus + * is not affected by a custom `equals` function defined on `T`. Also as standard in SQL, this + * function resolves columns by position (not by name). * @group typedrel * @since 2.4.0 */ - def intersectAll(other: DS[T]): DS[T] + def intersectAll(other: DS[T]): Dataset[T] /** - * Returns a new Dataset containing rows in this Dataset but not in another Dataset. - * This is equivalent to `EXCEPT DISTINCT` in SQL. + * Returns a new Dataset containing rows in this Dataset but not in another Dataset. This is + * equivalent to `EXCEPT DISTINCT` in SQL. * - * @note Equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. + * @note + * Equality checking is performed directly on the encoded representation of the data and thus + * is not affected by a custom `equals` function defined on `T`. * @group typedrel * @since 2.0.0 */ - def except(other: DS[T]): DS[T] + def except(other: DS[T]): Dataset[T] /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset while - * preserving the duplicates. - * This is equivalent to `EXCEPT ALL` in SQL. + * preserving the duplicates. This is equivalent to `EXCEPT ALL` in SQL. * - * @note Equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. Also as standard - * in SQL, this function resolves columns by position (not by name). + * @note + * Equality checking is performed directly on the encoded representation of the data and thus + * is not affected by a custom `equals` function defined on `T`. Also as standard in SQL, this + * function resolves columns by position (not by name). * @group typedrel * @since 2.4.0 */ - def exceptAll(other: DS[T]): DS[T] + def exceptAll(other: DS[T]): Dataset[T] /** - * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), - * using a user-supplied seed. + * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), using a + * user-supplied seed. * - * @param fraction Fraction of rows to generate, range [0.0, 1.0]. - * @param seed Seed for sampling. - * @note This is NOT guaranteed to provide exactly the fraction of the count - * of the given [[Dataset]]. + * @param fraction + * Fraction of rows to generate, range [0.0, 1.0]. + * @param seed + * Seed for sampling. + * @note + * This is NOT guaranteed to provide exactly the fraction of the count of the given + * [[Dataset]]. * @group typedrel * @since 2.3.0 */ - def sample(fraction: Double, seed: Long): DS[T] = { + def sample(fraction: Double, seed: Long): Dataset[T] = { sample(withReplacement = false, fraction = fraction, seed = seed) } /** - * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), - * using a random seed. + * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), using a + * random seed. * - * @param fraction Fraction of rows to generate, range [0.0, 1.0]. - * @note This is NOT guaranteed to provide exactly the fraction of the count - * of the given [[Dataset]]. 
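To make the set-operation distinctions above concrete, a small sketch (assumes a `SparkSession` named `spark` and `import spark.implicits._`):

{{{
  val left  = Seq(1, 1, 2, 3).toDF("id")
  val right = Seq(1, 2).toDF("id")

  left.except(right)       // rows: 3       (EXCEPT DISTINCT semantics)
  left.exceptAll(right)    // rows: 1, 3    (EXCEPT ALL keeps the surviving duplicate)
  left.intersect(right)    // rows: 1, 2    (INTERSECT semantics)
}}}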
+ * @param fraction + * Fraction of rows to generate, range [0.0, 1.0]. + * @note + * This is NOT guaranteed to provide exactly the fraction of the count of the given + * [[Dataset]]. * @group typedrel * @since 2.3.0 */ - def sample(fraction: Double): DS[T] = { + def sample(fraction: Double): Dataset[T] = { sample(withReplacement = false, fraction = fraction) } /** * Returns a new [[Dataset]] by sampling a fraction of rows, using a user-supplied seed. * - * @param withReplacement Sample with replacement or not. - * @param fraction Fraction of rows to generate, range [0.0, 1.0]. - * @param seed Seed for sampling. - * @note This is NOT guaranteed to provide exactly the fraction of the count - * of the given [[Dataset]]. + * @param withReplacement + * Sample with replacement or not. + * @param fraction + * Fraction of rows to generate, range [0.0, 1.0]. + * @param seed + * Seed for sampling. + * @note + * This is NOT guaranteed to provide exactly the fraction of the count of the given + * [[Dataset]]. * @group typedrel * @since 1.6.0 */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long): DS[T] + def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] /** * Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed. * - * @param withReplacement Sample with replacement or not. - * @param fraction Fraction of rows to generate, range [0.0, 1.0]. + * @param withReplacement + * Sample with replacement or not. + * @param fraction + * Fraction of rows to generate, range [0.0, 1.0]. * - * @note This is NOT guaranteed to provide exactly the fraction of the total count - * of the given [[Dataset]]. + * @note + * This is NOT guaranteed to provide exactly the fraction of the total count of the given + * [[Dataset]]. * * @group typedrel * @since 1.6.0 */ - def sample(withReplacement: Boolean, fraction: Double): DS[T] = { + def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = { sample(withReplacement, fraction, SparkClassUtils.random.nextLong) } /** * Randomly splits this Dataset with the provided weights. * - * @param weights weights for splits, will be normalized if they don't sum to 1. - * @param seed Seed for sampling. + * @param weights + * weights for splits, will be normalized if they don't sum to 1. + * @param seed + * Seed for sampling. * * For Java API, use [[randomSplitAsList]]. * * @group typedrel * @since 2.0.0 */ - def randomSplit(weights: Array[Double], seed: Long): Array[_ <: DS[T]] + def randomSplit(weights: Array[Double], seed: Long): Array[_ <: Dataset[T]] /** * Returns a Java list that contains randomly split Dataset with the provided weights. * - * @param weights weights for splits, will be normalized if they don't sum to 1. - * @param seed Seed for sampling. + * @param weights + * weights for splits, will be normalized if they don't sum to 1. + * @param seed + * Seed for sampling. * @group typedrel * @since 2.0.0 */ - def randomSplitAsList(weights: Array[Double], seed: Long): util.List[_ <: DS[T]] + def randomSplitAsList(weights: Array[Double], seed: Long): util.List[_ <: Dataset[T]] /** * Randomly splits this Dataset with the provided weights. * - * @param weights weights for splits, will be normalized if they don't sum to 1. + * @param weights + * weights for splits, will be normalized if they don't sum to 1. 
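A typical use of the `randomSplit` overloads documented above is a train/test split; `df` is an assumed input Dataset:

{{{
  val Array(train, test) = df.randomSplit(Array(0.8, 0.2), seed = 42L)
  // Weights are normalized, so Array(4.0, 1.0) would split the same way.
}}}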
* @group typedrel * @since 2.0.0 */ - def randomSplit(weights: Array[Double]): Array[_ <: DS[T]] + def randomSplit(weights: Array[Double]): Array[_ <: Dataset[T]] /** - * (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more - * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of - * the input row are implicitly joined with each row that is output by the function. + * (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more rows + * by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of the + * input row are implicitly joined with each row that is output by the function. * * Given that this is deprecated, as an alternative, you can explode columns either using * `functions.explode()` or `flatMap()`. The following example uses these alternatives to count @@ -1843,7 +2015,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * * {{{ * case class Book(title: String, words: String) - * val ds: DS[Book] + * val ds: Dataset[Book] * * val allWords = ds.select($"title", explode(split($"words", " ")).as("word")) * @@ -1860,12 +2032,13 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") - def explode[A <: Product : TypeTag](input: Column*)(f: Row => IterableOnce[A]): DS[Row] + def explode[A <: Product: TypeTag](input: Column*)(f: Row => IterableOnce[A]): Dataset[Row] /** - * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero - * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All - * columns of the input row are implicitly joined with each value that is output by the function. + * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero or + * more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All + * columns of the input row are implicitly joined with each value that is output by the + * function. * * Given that this is deprecated, as an alternative, you can explode columns either using * `functions.explode()`: @@ -1884,24 +2057,25 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") - def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => IterableOnce[B]) - : DS[Row] + def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)( + f: A => IterableOnce[B]): Dataset[Row] /** - * Returns a new Dataset by adding a column or replacing the existing column that has - * the same name. + * Returns a new Dataset by adding a column or replacing the existing column that has the same + * name. * - * `column`'s expression must only refer to attributes supplied by this Dataset. It is an - * error to add a column that refers to some other Dataset. + * `column`'s expression must only refer to attributes supplied by this Dataset. It is an error + * to add a column that refers to some other Dataset. * - * @note this method introduces a projection internally. Therefore, calling it multiple times, - * for instance, via loops in order to add multiple columns can generate big plans which - * can cause performance issues and even `StackOverflowException`. To avoid this, - * use `select` with the multiple columns at once. 
+ * @note + * this method introduces a projection internally. Therefore, calling it multiple times, for + * instance, via loops in order to add multiple columns can generate big plans which can cause + * performance issues and even `StackOverflowException`. To avoid this, use `select` with the + * multiple columns at once. * @group untypedrel * @since 2.0.0 */ - def withColumn(colName: String, col: Column): DS[Row] = withColumns(Seq(colName), Seq(col)) + def withColumn(colName: String, col: Column): Dataset[Row] = withColumns(Seq(colName), Seq(col)) /** * (Scala-specific) Returns a new Dataset by adding columns or replacing the existing columns @@ -1913,7 +2087,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 3.3.0 */ - def withColumns(colsMap: Map[String, Column]): DS[Row] = { + def withColumns(colsMap: Map[String, Column]): Dataset[Row] = { val (colNames, newCols) = colsMap.toSeq.unzip withColumns(colNames, newCols) } @@ -1928,60 +2102,55 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 3.3.0 */ - def withColumns(colsMap: util.Map[String, Column]): DS[Row] = withColumns( - colsMap.asScala.toMap - ) + def withColumns(colsMap: util.Map[String, Column]): Dataset[Row] = withColumns( + colsMap.asScala.toMap) /** - * Returns a new Dataset by adding columns or replacing the existing columns that has - * the same names. + * Returns a new Dataset by adding columns or replacing the existing columns that has the same + * names. */ - protected def withColumns(colNames: Seq[String], cols: Seq[Column]): DS[Row] + protected def withColumns(colNames: Seq[String], cols: Seq[Column]): Dataset[Row] /** - * Returns a new Dataset with a column renamed. - * This is a no-op if schema doesn't contain existingName. + * Returns a new Dataset with a column renamed. This is a no-op if schema doesn't contain + * existingName. * * @group untypedrel * @since 2.0.0 */ - def withColumnRenamed(existingName: String, newName: String): DS[Row] = + def withColumnRenamed(existingName: String, newName: String): Dataset[Row] = withColumnsRenamed(Seq(existingName), Seq(newName)) /** - * (Scala-specific) - * Returns a new Dataset with a columns renamed. - * This is a no-op if schema doesn't contain existingName. + * (Scala-specific) Returns a new Dataset with a columns renamed. This is a no-op if schema + * doesn't contain existingName. * * `colsMap` is a map of existing column name and new column name. * - * @throws org.apache.spark.sql.AnalysisException if there are duplicate names in resulting - * projection + * @throws org.apache.spark.sql.AnalysisException + * if there are duplicate names in resulting projection * @group untypedrel * @since 3.4.0 */ @throws[AnalysisException] - def withColumnsRenamed(colsMap: Map[String, String]): DS[Row] = { + def withColumnsRenamed(colsMap: Map[String, String]): Dataset[Row] = { val (colNames, newColNames) = colsMap.toSeq.unzip withColumnsRenamed(colNames, newColNames) } /** - * (Java-specific) - * Returns a new Dataset with a columns renamed. - * This is a no-op if schema doesn't contain existingName. + * (Java-specific) Returns a new Dataset with a columns renamed. This is a no-op if schema + * doesn't contain existingName. * * `colsMap` is a map of existing column name and new column name. 
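Because chaining `withColumn` in a loop is discouraged in the note above, here is a sketch of the batched alternatives; the column names are illustrative:

{{{
  import org.apache.spark.sql.functions.{col, current_timestamp}

  val enriched = df.withColumns(Map(
    "amount_usd" -> col("amount") * col("fx_rate"),
    "loaded_at"  -> current_timestamp()))

  val renamed = enriched.withColumnsRenamed(Map("amount_usd" -> "amountUsd"))
}}}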
* * @group untypedrel * @since 3.4.0 */ - def withColumnsRenamed(colsMap: util.Map[String, String]): DS[Row] = + def withColumnsRenamed(colsMap: util.Map[String, String]): Dataset[Row] = withColumnsRenamed(colsMap.asScala.toMap) - protected def withColumnsRenamed( - colNames: Seq[String], - newColNames: Seq[String]): DS[Row] + protected def withColumnsRenamed(colNames: Seq[String], newColNames: Seq[String]): Dataset[Row] /** * Returns a new Dataset by updating an existing column with metadata. @@ -1989,17 +2158,17 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 3.3.0 */ - def withMetadata(columnName: String, metadata: Metadata): DS[Row] + def withMetadata(columnName: String, metadata: Metadata): Dataset[Row] /** - * Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain - * column name. + * Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain column + * name. * * This method can only be used to drop top level columns. the colName string is treated * literally without further interpretation. * - * Note: `drop(colName)` has different semantic with `drop(col(colName))`, for example: - * 1, multi column have the same colName: + * Note: `drop(colName)` has different semantic with `drop(col(colName))`, for example: 1, multi + * column have the same colName: * {{{ * val df1 = spark.range(0, 2).withColumn("key1", lit(1)) * val df2 = spark.range(0, 2).withColumn("key2", lit(2)) @@ -2062,111 +2231,108 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def drop(colName: String): DS[Row] = drop(colName :: Nil : _*) + def drop(colName: String): Dataset[Row] = drop(colName :: Nil: _*) /** - * Returns a new Dataset with columns dropped. - * This is a no-op if schema doesn't contain column name(s). + * Returns a new Dataset with columns dropped. This is a no-op if schema doesn't contain column + * name(s). * - * This method can only be used to drop top level columns. the colName string is treated literally - * without further interpretation. + * This method can only be used to drop top level columns. the colName string is treated + * literally without further interpretation. * * @group untypedrel * @since 2.0.0 */ @scala.annotation.varargs - def drop(colNames: String*): DS[Row] + def drop(colNames: String*): Dataset[Row] /** * Returns a new Dataset with column dropped. * - * This method can only be used to drop top level column. - * This version of drop accepts a [[org.apache.spark.sql.Column]] rather than a name. - * This is a no-op if the Dataset doesn't have a column - * with an equivalent expression. + * This method can only be used to drop top level column. This version of drop accepts a + * [[org.apache.spark.sql.Column]] rather than a name. This is a no-op if the Dataset doesn't + * have a column with an equivalent expression. * - * Note: `drop(col(colName))` has different semantic with `drop(colName)`, - * please refer to `Dataset#drop(colName: String)`. + * Note: `drop(col(colName))` has different semantic with `drop(colName)`, please refer to + * `Dataset#drop(colName: String)`. * * @group untypedrel * @since 2.0.0 */ - def drop(col: Column): DS[Row] = drop(col, Nil : _*) + def drop(col: Column): Dataset[Row] = drop(col, Nil: _*) /** * Returns a new Dataset with columns dropped. * - * This method can only be used to drop top level columns. 
- * This is a no-op if the Dataset doesn't have a columns - * with an equivalent expression. + * This method can only be used to drop top level columns. This is a no-op if the Dataset + * doesn't have a columns with an equivalent expression. * * @group untypedrel * @since 3.4.0 */ @scala.annotation.varargs - def drop(col: Column, cols: Column*): DS[Row] + def drop(col: Column, cols: Column*): Dataset[Row] /** - * Returns a new Dataset that contains only the unique rows from this Dataset. - * This is an alias for `distinct`. + * Returns a new Dataset that contains only the unique rows from this Dataset. This is an alias + * for `distinct`. * * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it * will keep all data across triggers as intermediate state to drop duplicates rows. You can use - * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit - * the state. In addition, too late data older than watermark will be dropped to avoid any + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly + * limit the state. In addition, too late data older than watermark will be dropped to avoid any * possibility of duplicates. * * @group typedrel * @since 2.0.0 */ - def dropDuplicates(): DS[T] + def dropDuplicates(): Dataset[T] /** - * (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only - * the subset of columns. + * (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only the + * subset of columns. * * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it * will keep all data across triggers as intermediate state to drop duplicates rows. You can use - * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit - * the state. In addition, too late data older than watermark will be dropped to avoid any + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly + * limit the state. In addition, too late data older than watermark will be dropped to avoid any * possibility of duplicates. * * @group typedrel * @since 2.0.0 */ - def dropDuplicates(colNames: Seq[String]): DS[T] + def dropDuplicates(colNames: Seq[String]): Dataset[T] /** - * Returns a new Dataset with duplicate rows removed, considering only - * the subset of columns. + * Returns a new Dataset with duplicate rows removed, considering only the subset of columns. * * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it * will keep all data across triggers as intermediate state to drop duplicates rows. You can use - * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit - * the state. In addition, too late data older than watermark will be dropped to avoid any + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly + * limit the state. In addition, too late data older than watermark will be dropped to avoid any * possibility of duplicates. * * @group typedrel * @since 2.0.0 */ - def dropDuplicates(colNames: Array[String]): DS[T] = + def dropDuplicates(colNames: Array[String]): Dataset[T] = dropDuplicates(colNames.toImmutableArraySeq) /** - * Returns a new [[Dataset]] with duplicate rows removed, considering only - * the subset of columns. 
+ * Returns a new [[Dataset]] with duplicate rows removed, considering only the subset of + * columns. * * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it * will keep all data across triggers as intermediate state to drop duplicates rows. You can use - * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit - * the state. In addition, too late data older than watermark will be dropped to avoid any + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly + * limit the state. In addition, too late data older than watermark will be dropped to avoid any * possibility of duplicates. * * @group typedrel * @since 2.0.0 */ @scala.annotation.varargs - def dropDuplicates(col1: String, cols: String*): DS[T] = { + def dropDuplicates(col1: String, cols: String*): Dataset[T] = { val colNames: Seq[String] = col1 +: cols dropDuplicates(colNames) } @@ -2177,8 +2343,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * This only works with streaming [[Dataset]], and watermark for the input [[Dataset]] must be * set via [[withWatermark]]. * - * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state - * to drop duplicated rows. The state will be kept to guarantee the semantic, "Events are + * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state to + * drop duplicated rows. The state will be kept to guarantee the semantic, "Events are * deduplicated as long as the time distance of earliest and latest events are smaller than the * delay threshold of watermark." Users are encouraged to set the delay threshold of watermark * longer than max timestamp differences among duplicated events. @@ -2188,7 +2354,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 3.5.0 */ - def dropDuplicatesWithinWatermark(): DS[T] + def dropDuplicatesWithinWatermark(): Dataset[T] /** * Returns a new Dataset with duplicates rows removed, considering only the subset of columns, @@ -2197,8 +2363,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * This only works with streaming [[Dataset]], and watermark for the input [[Dataset]] must be * set via [[withWatermark]]. * - * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state - * to drop duplicated rows. The state will be kept to guarantee the semantic, "Events are + * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state to + * drop duplicated rows. The state will be kept to guarantee the semantic, "Events are * deduplicated as long as the time distance of earliest and latest events are smaller than the * delay threshold of watermark." Users are encouraged to set the delay threshold of watermark * longer than max timestamp differences among duplicated events. 
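A sketch of the watermark-bounded de-duplication described above, for a streaming Dataset `events` with assumed `eventTime` and `eventId` columns:

{{{
  val deduped = events
    .withWatermark("eventTime", "10 minutes")
    .dropDuplicatesWithinWatermark("eventId")
}}}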
@@ -2208,7 +2374,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 3.5.0 */ - def dropDuplicatesWithinWatermark(colNames: Seq[String]): DS[T] + def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] /** * Returns a new Dataset with duplicates rows removed, considering only the subset of columns, @@ -2217,8 +2383,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * This only works with streaming [[Dataset]], and watermark for the input [[Dataset]] must be * set via [[withWatermark]]. * - * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state - * to drop duplicated rows. The state will be kept to guarantee the semantic, "Events are + * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state to + * drop duplicated rows. The state will be kept to guarantee the semantic, "Events are * deduplicated as long as the time distance of earliest and latest events are smaller than the * delay threshold of watermark." Users are encouraged to set the delay threshold of watermark * longer than max timestamp differences among duplicated events. @@ -2228,7 +2394,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 3.5.0 */ - def dropDuplicatesWithinWatermark(colNames: Array[String]): DS[T] = { + def dropDuplicatesWithinWatermark(colNames: Array[String]): Dataset[T] = { dropDuplicatesWithinWatermark(colNames.toImmutableArraySeq) } @@ -2239,8 +2405,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * This only works with streaming [[Dataset]], and watermark for the input [[Dataset]] must be * set via [[withWatermark]]. * - * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state - * to drop duplicated rows. The state will be kept to guarantee the semantic, "Events are + * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state to + * drop duplicated rows. The state will be kept to guarantee the semantic, "Events are * deduplicated as long as the time distance of earliest and latest events are smaller than the * delay threshold of watermark." Users are encouraged to set the delay threshold of watermark * longer than max timestamp differences among duplicated events. @@ -2251,7 +2417,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 3.5.0 */ @scala.annotation.varargs - def dropDuplicatesWithinWatermark(col1: String, cols: String*): DS[T] = { + def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] = { val colNames: Seq[String] = col1 +: cols dropDuplicatesWithinWatermark(colNames) } @@ -2279,28 +2445,22 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * * Use [[summary]] for expanded statistics and control over which statistics to compute. * - * @param cols Columns to compute statistics on. + * @param cols + * Columns to compute statistics on. * @group action * @since 1.6.0 */ @scala.annotation.varargs - def describe(cols: String*): DS[Row] - - /** - * Computes specified statistics for numeric and string columns. Available statistics are: - *
<ul> - *   <li>count</li> - *   <li>mean</li> - *   <li>stddev</li> - *   <li>min</li> - *   <li>max</li>
- *   <li>arbitrary approximate percentiles specified as a percentage (e.g. 75%)</li> - *   <li>count_distinct</li>
- *   <li>approx_count_distinct</li> - * </ul>
+ def describe(cols: String*): Dataset[Row] + + /** + * Computes specified statistics for numeric and string columns. Available statistics are:
+ * <ul> + * <li> count </li> <li> mean </li> <li> stddev </li> <li> min </li> <li> max </li> <li> arbitrary
+ * approximate percentiles specified as a percentage (e.g. 75%) </li> <li> count_distinct </li>
+ * <li> approx_count_distinct </li> </ul>
    * - * If no statistics are given, this function computes count, mean, stddev, min, - * approximate quartiles (percentiles at 25%, 50%, and 75%), and max. + * If no statistics are given, this function computes count, mean, stddev, min, approximate + * quartiles (percentiles at 25%, 50%, and 75%), and max. * * This function is meant for exploratory data analysis, as we make no guarantee about the * backward compatibility of the schema of the resulting Dataset. If you want to @@ -2355,18 +2515,20 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * * See also [[describe]] for basic statistics. * - * @param statistics Statistics from above list to be computed. + * @param statistics + * Statistics from above list to be computed. * @group action * @since 2.3.0 */ @scala.annotation.varargs - def summary(statistics: String*): DS[Row] + def summary(statistics: String*): Dataset[Row] /** * Returns the first `n` rows. * - * @note this method should only be used if the resulting array is expected to be small, as - * all the data is loaded into the driver's memory. + * @note + * this method should only be used if the resulting array is expected to be small, as all the + * data is loaded into the driver's memory. * @group action * @since 1.6.0 */ @@ -2391,7 +2553,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { /** * Concise syntax for chaining custom transformations. * {{{ - * def featurize(ds: DS[T]): DS[U] = ... + * def featurize(ds: Dataset[T]): Dataset[U] = ... * * ds * .transform(featurize) @@ -2401,66 +2563,64 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def transform[U](t: DS[T] => DS[U]): DS[U] = t(this.asInstanceOf[DS[T]]) + def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this.asInstanceOf[Dataset[T]]) /** - * (Scala-specific) - * Returns a new Dataset that contains the result of applying `func` to each element. + * (Scala-specific) Returns a new Dataset that contains the result of applying `func` to each + * element. * * @group typedrel * @since 1.6.0 */ - def map[U: Encoder](func: T => U): DS[U] + def map[U: Encoder](func: T => U): Dataset[U] /** - * (Java-specific) - * Returns a new Dataset that contains the result of applying `func` to each element. + * (Java-specific) Returns a new Dataset that contains the result of applying `func` to each + * element. * * @group typedrel * @since 1.6.0 */ - def map[U](func: MapFunction[T, U], encoder: Encoder[U]): DS[U] + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] /** - * (Scala-specific) - * Returns a new Dataset that contains the result of applying `func` to each partition. + * (Scala-specific) Returns a new Dataset that contains the result of applying `func` to each + * partition. * * @group typedrel * @since 1.6.0 */ - def mapPartitions[U: Encoder](func: Iterator[T] => Iterator[U]): DS[U] + def mapPartitions[U: Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] /** - * (Java-specific) - * Returns a new Dataset that contains the result of applying `f` to each partition. + * (Java-specific) Returns a new Dataset that contains the result of applying `f` to each + * partition. 
* * @group typedrel * @since 1.6.0 */ - def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): DS[U] + def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = + mapPartitions(ToScalaUDF(f))(encoder) /** - * (Scala-specific) - * Returns a new Dataset by first applying a function to all elements of this Dataset, - * and then flattening the results. + * (Scala-specific) Returns a new Dataset by first applying a function to all elements of this + * Dataset, and then flattening the results. * * @group typedrel * @since 1.6.0 */ - def flatMap[U: Encoder](func: T => IterableOnce[U]): DS[U] = - mapPartitions(_.flatMap(func)) + def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] = + mapPartitions(UDFAdaptors.flatMapToMapPartitions[T, U](func)) /** - * (Java-specific) - * Returns a new Dataset by first applying a function to all elements of this Dataset, - * and then flattening the results. + * (Java-specific) Returns a new Dataset by first applying a function to all elements of this + * Dataset, and then flattening the results. * * @group typedrel * @since 1.6.0 */ - def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): DS[U] = { - val func: T => Iterator[U] = x => f.call(x).asScala - flatMap(func)(encoder) + def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + mapPartitions(UDFAdaptors.flatMapToMapPartitions(f))(encoder) } /** @@ -2469,16 +2629,19 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group action * @since 1.6.0 */ - def foreach(f: T => Unit): Unit + def foreach(f: T => Unit): Unit = { + foreachPartition(UDFAdaptors.foreachToForeachPartition(f)) + } /** - * (Java-specific) - * Runs `func` on each element of this Dataset. + * (Java-specific) Runs `func` on each element of this Dataset. * * @group action * @since 1.6.0 */ - def foreach(func: ForeachFunction[T]): Unit = foreach(func.call) + def foreach(func: ForeachFunction[T]): Unit = { + foreachPartition(UDFAdaptors.foreachToForeachPartition(func)) + } /** * Applies a function `f` to each partition of this Dataset. @@ -2489,21 +2652,20 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { def foreachPartition(f: Iterator[T] => Unit): Unit /** - * (Java-specific) - * Runs `func` on each partition of this Dataset. + * (Java-specific) Runs `func` on each partition of this Dataset. * * @group action * @since 1.6.0 */ def foreachPartition(func: ForeachPartitionFunction[T]): Unit = { - foreachPartition((it: Iterator[T]) => func.call(it.asJava)) + foreachPartition(ToScalaUDF(func)) } /** * Returns the first `n` rows in the Dataset. * - * Running take requires moving data into the application's driver process, and doing so with - * a very large `n` can crash the driver process with OutOfMemoryError. + * Running take requires moving data into the application's driver process, and doing so with a + * very large `n` can crash the driver process with OutOfMemoryError. * * @group action * @since 1.6.0 @@ -2513,8 +2675,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { /** * Returns the last `n` rows in the Dataset. * - * Running tail requires moving data into the application's driver process, and doing so with - * a very large `n` can crash the driver process with OutOfMemoryError. + * Running tail requires moving data into the application's driver process, and doing so with a + * very large `n` can crash the driver process with OutOfMemoryError. 
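The per-partition variants above are typically used to amortize expensive setup across a whole partition instead of paying it per row. A sketch for a `Dataset[String]` named `ds`; `openConnection()` and the resulting `send`/`close` calls are hypothetical helpers, not Spark APIs:

{{{
  ds.foreachPartition { (rows: Iterator[String]) =>
    val conn = openConnection()   // hypothetical: one connection per partition
    try rows.foreach(conn.send) finally conn.close()
  }
}}}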
* * @group action * @since 3.0.0 @@ -2524,8 +2686,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { /** * Returns the first `n` rows in the Dataset as a list. * - * Running take requires moving data into the application's driver process, and doing so with - * a very large `n` can crash the driver process with OutOfMemoryError. + * Running take requires moving data into the application's driver process, and doing so with a + * very large `n` can crash the driver process with OutOfMemoryError. * * @group action * @since 1.6.0 @@ -2535,8 +2697,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { /** * Returns an array that contains all rows in this Dataset. * - * Running collect requires moving all the data into the application's driver process, and - * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * Running collect requires moving all the data into the application's driver process, and doing + * so on a very large dataset can crash the driver process with OutOfMemoryError. * * For Java API, use [[collectAsList]]. * @@ -2548,8 +2710,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { /** * Returns a Java list that contains all rows in this Dataset. * - * Running collect requires moving all the data into the application's driver process, and - * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * Running collect requires moving all the data into the application's driver process, and doing + * so on a very large dataset can crash the driver process with OutOfMemoryError. * * @group action * @since 1.6.0 @@ -2561,9 +2723,10 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * * The iterator will consume as much memory as the largest partition in this Dataset. * - * @note this results in multiple Spark jobs, and if the input Dataset is the result - * of a wide transformation (e.g. join with different partitioners), to avoid - * recomputing the input Dataset should be cached first. + * @note + * this results in multiple Spark jobs, and if the input Dataset is the result of a wide + * transformation (e.g. join with different partitioners), to avoid recomputing the input + * Dataset should be cached first. * @group action * @since 2.0.0 */ @@ -2583,15 +2746,15 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def repartition(numPartitions: Int): DS[T] + def repartition(numPartitions: Int): Dataset[T] protected def repartitionByExpression( numPartitions: Option[Int], - partitionExprs: Seq[Column]): DS[T] + partitionExprs: Seq[Column]): Dataset[T] /** - * Returns a new Dataset partitioned by the given partitioning expressions into - * `numPartitions`. The resulting Dataset is hash partitioned. + * Returns a new Dataset partitioned by the given partitioning expressions into `numPartitions`. + * The resulting Dataset is hash partitioned. * * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). 
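A sketch of the hash-partitioning forms documented above; `orders` and its columns are illustrative:

{{{
  import org.apache.spark.sql.functions.col

  // Explicit partition count plus partitioning expression:
  val byCustomer = orders.repartition(200, col("customer_id"))

  // Let spark.sql.shuffle.partitions pick the count:
  val byDate = orders.repartition(col("order_date"))
}}}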
* @@ -2599,14 +2762,14 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def repartition(numPartitions: Int, partitionExprs: Column*): DS[T] = { + def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { repartitionByExpression(Some(numPartitions), partitionExprs) } /** * Returns a new Dataset partitioned by the given partitioning expressions, using - * `spark.sql.shuffle.partitions` as number of partitions. - * The resulting Dataset is hash partitioned. + * `spark.sql.shuffle.partitions` as number of partitions. The resulting Dataset is hash + * partitioned. * * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). * @@ -2614,92 +2777,88 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def repartition(partitionExprs: Column*): DS[T] = { + def repartition(partitionExprs: Column*): Dataset[T] = { repartitionByExpression(None, partitionExprs) } protected def repartitionByRange( numPartitions: Option[Int], - partitionExprs: Seq[Column]): DS[T] - + partitionExprs: Seq[Column]): Dataset[T] /** - * Returns a new Dataset partitioned by the given partitioning expressions into - * `numPartitions`. The resulting Dataset is range partitioned. - * - * At least one partition-by expression must be specified. - * When no explicit sort order is specified, "ascending nulls first" is assumed. - * Note, the rows are not sorted in each partition of the resulting Dataset. + * Returns a new Dataset partitioned by the given partitioning expressions into `numPartitions`. + * The resulting Dataset is range partitioned. * + * At least one partition-by expression must be specified. When no explicit sort order is + * specified, "ascending nulls first" is assumed. Note, the rows are not sorted in each + * partition of the resulting Dataset. * - * Note that due to performance reasons this method uses sampling to estimate the ranges. - * Hence, the output may not be consistent, since sampling can return different values. - * The sample size can be controlled by the config - * `spark.sql.execution.rangeExchange.sampleSizePerPartition`. + * Note that due to performance reasons this method uses sampling to estimate the ranges. Hence, + * the output may not be consistent, since sampling can return different values. The sample size + * can be controlled by the config `spark.sql.execution.rangeExchange.sampleSizePerPartition`. * * @group typedrel * @since 2.3.0 */ @scala.annotation.varargs - def repartitionByRange(numPartitions: Int, partitionExprs: Column*): DS[T] = { + def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { repartitionByRange(Some(numPartitions), partitionExprs) } /** * Returns a new Dataset partitioned by the given partitioning expressions, using - * `spark.sql.shuffle.partitions` as number of partitions. - * The resulting Dataset is range partitioned. + * `spark.sql.shuffle.partitions` as number of partitions. The resulting Dataset is range + * partitioned. * - * At least one partition-by expression must be specified. - * When no explicit sort order is specified, "ascending nulls first" is assumed. - * Note, the rows are not sorted in each partition of the resulting Dataset. + * At least one partition-by expression must be specified. When no explicit sort order is + * specified, "ascending nulls first" is assumed. Note, the rows are not sorted in each + * partition of the resulting Dataset. 
* - * Note that due to performance reasons this method uses sampling to estimate the ranges. - * Hence, the output may not be consistent, since sampling can return different values. - * The sample size can be controlled by the config - * `spark.sql.execution.rangeExchange.sampleSizePerPartition`. + * Note that due to performance reasons this method uses sampling to estimate the ranges. Hence, + * the output may not be consistent, since sampling can return different values. The sample size + * can be controlled by the config `spark.sql.execution.rangeExchange.sampleSizePerPartition`. * * @group typedrel * @since 2.3.0 */ @scala.annotation.varargs - def repartitionByRange(partitionExprs: Column*): DS[T] = { + def repartitionByRange(partitionExprs: Column*): Dataset[T] = { repartitionByRange(None, partitionExprs) } /** * Returns a new Dataset that has exactly `numPartitions` partitions, when the fewer partitions * are requested. If a larger number of partitions is requested, it will stay at the current - * number of partitions. Similar to coalesce defined on an `RDD`, this operation results in - * a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not - * be a shuffle, instead each of the 100 new partitions will claim 10 of the current partitions. + * number of partitions. Similar to coalesce defined on an `RDD`, this operation results in a + * narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not be a + * shuffle, instead each of the 100 new partitions will claim 10 of the current partitions. * - * However, if you're doing a drastic coalesce, e.g. to numPartitions = 1, - * this may result in your computation taking place on fewer nodes than - * you like (e.g. one node in the case of numPartitions = 1). To avoid this, - * you can call repartition. This will add a shuffle step, but means the - * current upstream partitions will be executed in parallel (per whatever - * the current partitioning is). + * However, if you're doing a drastic coalesce, e.g. to numPartitions = 1, this may result in + * your computation taking place on fewer nodes than you like (e.g. one node in the case of + * numPartitions = 1). To avoid this, you can call repartition. This will add a shuffle step, + * but means the current upstream partitions will be executed in parallel (per whatever the + * current partitioning is). * * @group typedrel * @since 1.6.0 */ - def coalesce(numPartitions: Int): DS[T] + def coalesce(numPartitions: Int): Dataset[T] /** - * Returns a new Dataset that contains only the unique rows from this Dataset. - * This is an alias for `dropDuplicates`. + * Returns a new Dataset that contains only the unique rows from this Dataset. This is an alias + * for `dropDuplicates`. * - * Note that for a streaming [[Dataset]], this method returns distinct rows only once - * regardless of the output mode, which the behavior may not be same with `DISTINCT` in SQL - * against streaming [[Dataset]]. + * Note that for a streaming [[Dataset]], this method returns distinct rows only once regardless + * of the output mode, which the behavior may not be same with `DISTINCT` in SQL against + * streaming [[Dataset]]. * - * @note Equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. + * @note + * Equality checking is performed directly on the encoded representation of the data and thus + * is not affected by a custom `equals` function defined on `T`. 
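To contrast `coalesce` and `repartition` as described above, a small sketch with an assumed DataFrame `logs`:

{{{
  import org.apache.spark.sql.functions.col

  // Narrow dependency, no shuffle; a common choice after a selective filter:
  val few = logs.filter(col("level") === "ERROR").coalesce(8)

  // Full shuffle, keeps upstream parallelism; prefer this for drastic shrinking:
  val one = logs.repartition(1)
}}}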
* @group typedrel * @since 2.0.0 */ - def distinct(): DS[T] = dropDuplicates() + def distinct(): Dataset[T] = dropDuplicates() /** * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`). @@ -2707,7 +2866,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group basic * @since 1.6.0 */ - def persist(): DS[T] + def persist(): Dataset[T] /** * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`). @@ -2715,19 +2874,18 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group basic * @since 1.6.0 */ - def cache(): DS[T] - + def cache(): Dataset[T] /** * Persist this Dataset with the given storage level. * - * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, - * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, - * `MEMORY_AND_DISK_2`, etc. + * @param newLevel + * One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, `MEMORY_AND_DISK_SER`, + * `DISK_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK_2`, etc. * @group basic * @since 1.6.0 */ - def persist(newLevel: StorageLevel): DS[T] + def persist(newLevel: StorageLevel): Dataset[T] /** * Get the Dataset's current storage level, or StorageLevel.NONE if not persisted. @@ -2738,23 +2896,24 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { def storageLevel: StorageLevel /** - * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. - * This will not un-persist any cached data that is built upon this Dataset. + * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. This + * will not un-persist any cached data that is built upon this Dataset. * - * @param blocking Whether to block until all blocks are deleted. + * @param blocking + * Whether to block until all blocks are deleted. * @group basic * @since 1.6.0 */ - def unpersist(blocking: Boolean): DS[T] + def unpersist(blocking: Boolean): Dataset[T] /** - * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. - * This will not un-persist any cached data that is built upon this Dataset. + * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. This + * will not un-persist any cached data that is built upon this Dataset. * * @group basic * @since 1.6.0 */ - def unpersist(): DS[T] + def unpersist(): Dataset[T] /** * Registers this Dataset as a temporary table using the given name. The lifetime of this @@ -2769,14 +2928,15 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { } /** - * Creates a local temporary view using the given name. The lifetime of this - * temporary view is tied to the `SparkSession` that was used to create this Dataset. + * Creates a local temporary view using the given name. The lifetime of this temporary view is + * tied to the `SparkSession` that was used to create this Dataset. * * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that - * created it, i.e. it will be automatically dropped when the session terminates. It's not - * tied to any databases, i.e. we can't use `db1.view1` to reference a local temporary view. + * created it, i.e. it will be automatically dropped when the session terminates. It's not tied + * to any databases, i.e. we can't use `db1.view1` to reference a local temporary view. 
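A sketch of the caching lifecycle described above:

{{{
  import org.apache.spark.storage.StorageLevel

  val cached = ds.persist(StorageLevel.MEMORY_AND_DISK_SER)
  cached.count()                      // first action materializes the cache
  // ... reuse `cached` across several actions ...
  cached.unpersist(blocking = false)
}}}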
* - * @throws org.apache.spark.sql.AnalysisException if the view name is invalid or already exists + * @throws org.apache.spark.sql.AnalysisException + * if the view name is invalid or already exists * @group basic * @since 2.0.0 */ @@ -2785,10 +2945,9 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { createTempView(viewName, replace = false, global = false) } - /** - * Creates a local temporary view using the given name. The lifetime of this - * temporary view is tied to the `SparkSession` that was used to create this Dataset. + * Creates a local temporary view using the given name. The lifetime of this temporary view is + * tied to the `SparkSession` that was used to create this Dataset. * * @group basic * @since 2.0.0 @@ -2798,15 +2957,16 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { } /** - * Creates a global temporary view using the given name. The lifetime of this - * temporary view is tied to this Spark application. + * Creates a global temporary view using the given name. The lifetime of this temporary view is + * tied to this Spark application. * - * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application, - * i.e. it will be automatically dropped when the application terminates. It's tied to a system - * preserved database `global_temp`, and we must use the qualified name to refer a global temp - * view, e.g. `SELECT * FROM global_temp.view1`. + * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark + * application, i.e. it will be automatically dropped when the application terminates. It's tied + * to a system preserved database `global_temp`, and we must use the qualified name to refer a + * global temp view, e.g. `SELECT * FROM global_temp.view1`. * - * @throws org.apache.spark.sql.AnalysisException if the view name is invalid or already exists + * @throws org.apache.spark.sql.AnalysisException + * if the view name is invalid or already exists * @group basic * @since 2.1.0 */ @@ -2819,10 +2979,10 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * Creates or replaces a global temporary view using the given name. The lifetime of this * temporary view is tied to this Spark application. * - * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application, - * i.e. it will be automatically dropped when the application terminates. It's tied to a system - * preserved database `global_temp`, and we must use the qualified name to refer a global temp - * view, e.g. `SELECT * FROM global_temp.view1`. + * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark + * application, i.e. it will be automatically dropped when the application terminates. It's tied + * to a system preserved database `global_temp`, and we must use the qualified name to refer a + * global temp view, e.g. `SELECT * FROM global_temp.view1`. * * @group basic * @since 2.2.0 @@ -2833,17 +2993,63 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { protected def createTempView(viewName: String, replace: Boolean, global: Boolean): Unit + /** + * Merges a set of updates, insertions, and deletions based on a source table into a target + * table. 
+ * + * Scala Examples: + * {{{ + * spark.table("source") + * .mergeInto("target", $"source.id" === $"target.id") + * .whenMatched($"salary" === 100) + * .delete() + * .whenNotMatched() + * .insertAll() + * .whenNotMatchedBySource($"salary" === 100) + * .update(Map( + * "salary" -> lit(200) + * )) + * .merge() + * }}} + * + * @group basic + * @since 4.0.0 + */ + def mergeInto(table: String, condition: Column): MergeIntoWriter[T] + + /** + * Create a write configuration builder for v2 sources. + * + * This builder is used to configure and execute write operations. For example, to append to an + * existing table, run: + * + * {{{ + * df.writeTo("catalog.db.table").append() + * }}} + * + * This can also be used to create or replace existing tables: + * + * {{{ + * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace() + * }}} + * + * @group basic + * @since 3.0.0 + */ + def writeTo(table: String): DataFrameWriterV2[T] + /** * Returns the content of the Dataset as a Dataset of JSON strings. * * @since 2.0.0 */ - def toJSON: DS[String] + def toJSON: Dataset[String] /** * Returns a best-effort snapshot of the files that compose this Dataset. This method simply - * asks each constituent BaseRelation for its respective files and takes the union of all results. - * Depending on the source relations, this may not find all input files. Duplicates are removed. + * asks each constituent BaseRelation for its respective files and takes the union of all + * results. Depending on the source relations, this may not find all input files. Duplicates are + * removed. * * @group basic * @since 2.0.0 @@ -2851,14 +3057,16 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { def inputFiles: Array[String] /** - * Returns `true` when the logical query plans inside both [[Dataset]]s are equal and - * therefore return same results. + * Returns `true` when the logical query plans inside both [[Dataset]]s are equal and therefore + * return same results. * - * @note The equality comparison here is simplified by tolerating the cosmetic differences - * such as attribute names. - * @note This API can compare both [[Dataset]]s very fast but can still return `false` on - * the [[Dataset]] that return the same results, for instance, from different plans. Such - * false negative semantic can be useful when caching as an example. + * @note + * The equality comparison here is simplified by tolerating the cosmetic differences such as + * attribute names. + * @note + * This API can compare both [[Dataset]]s very fast but can still return `false` on the + * [[Dataset]] that return the same results, for instance, from different plans. Such false + * negative semantic can be useful when caching as an example. * @since 3.1.0 */ @DeveloperApi @@ -2867,8 +3075,9 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { /** * Returns a `hashCode` of the logical query plan against this [[Dataset]]. * - * @note Unlike the standard `hashCode`, the hash is calculated against the query plan - * simplified by tolerating the cosmetic differences such as attribute names. + * @note + * Unlike the standard `hashCode`, the hash is calculated against the query plan simplified by + * tolerating the cosmetic differences such as attribute names. 
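A minimal upsert using the `mergeInto` builder introduced above, mirroring the documented pattern; the table names are illustrative and `import spark.implicits._` is assumed for the `$` syntax:

{{{
  spark.table("source")
    .mergeInto("target", $"source.id" === $"target.id")
    .whenMatched()
    .updateAll()
    .whenNotMatched()
    .insertAll()
    .merge()
}}}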
* @since 3.1.0 */ @DeveloperApi diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala new file mode 100644 index 0000000000000..81f999430a128 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala @@ -0,0 +1,1022 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.api + +import org.apache.spark.api.java.function.{CoGroupFunction, FlatMapGroupsFunction, FlatMapGroupsWithStateFunction, MapFunction, MapGroupsFunction, MapGroupsWithStateFunction, ReduceFunction} +import org.apache.spark.sql.{Column, Encoder, TypedColumn} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveLongEncoder +import org.apache.spark.sql.functions.{count => cnt, lit} +import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode} + +/** + * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not + * construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupByKey` on an + * existing [[Dataset]]. + * + * @since 2.0.0 + */ +abstract class KeyValueGroupedDataset[K, V] extends Serializable { + type KVDS[KL, VL] <: KeyValueGroupedDataset[KL, VL] + + /** + * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the + * specified type. The mapping of key columns to the type follows the same rules as `as` on + * [[Dataset]]. + * + * @since 1.6.0 + */ + def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] + + /** + * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to + * the data. The grouping key is unchanged by this. + * + * {{{ + * // Create values grouped by key from a Dataset[(K, V)] + * ds.groupByKey(_._1).mapValues(_._2) // Scala + * }}} + * + * @since 2.1.0 + */ + def mapValues[W: Encoder](func: V => W): KeyValueGroupedDataset[K, W] + + /** + * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to + * the data. The grouping key is unchanged by this. + * + * {{{ + * // Create Integer values grouped by String key from a Dataset> + * Dataset> ds = ...; + * KeyValueGroupedDataset grouped = + * ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT()); + * }}} + * + * @since 2.1.0 + */ + def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = { + mapValues(ToScalaUDF(func))(encoder) + } + + /** + * Returns a [[Dataset]] that contains each unique key. 
This is equivalent to doing mapping over + * the Dataset to extract the keys and then running a distinct operation on those. + * + * @since 1.6.0 + */ + def keys: Dataset[K] + + /** + * (Scala-specific) Applies the given function to each group of data. For each unique group, the + * function will be passed the group key and an iterator that contains all of the elements in + * the group. The function can return an iterator containing elements of an arbitrary type which + * will be returned as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * `org.apache.spark.sql.expressions#Aggregator`. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the + * memory constraints of their cluster. + * + * @since 1.6.0 + */ + def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { + flatMapSortedGroups(Nil: _*)(f) + } + + /** + * (Java-specific) Applies the given function to each group of data. For each unique group, the + * function will be passed the group key and an iterator that contains all of the elements in + * the group. The function can return an iterator containing elements of an arbitrary type which + * will be returned as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * `org.apache.spark.sql.expressions#Aggregator`. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the + * memory constraints of their cluster. + * + * @since 1.6.0 + */ + def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + flatMapGroups(ToScalaUDF(f))(encoder) + } + + /** + * (Scala-specific) Applies the given function to each group of data. For each unique group, the + * function will be passed the group key and a sorted iterator that contains all of the elements + * in the group. The function can return an iterator containing elements of an arbitrary type + * which will be returned as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * `org.apache.spark.sql.expressions#Aggregator`. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the + * memory constraints of their cluster. 
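+   *
+   * A minimal sketch, assuming a hypothetical case class `Click(userId: String, ts: Long)` and
+   * `spark.implicits._` in scope:
+   * {{{
+   *   ds.groupByKey(_.userId)
+   *     .flatMapSortedGroups($"ts")((userId, clicks) =>
+   *       // clicks arrive sorted by ts; emit the first click per user
+   *       clicks.take(1).map(c => (userId, c.ts)))
+   * }}}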
+ * + * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be + * sorted according to the given sort expressions. That sorting does not add computational + * complexity. + * + * @see + * `org.apache.spark.sql.api.KeyValueGroupedDataset#flatMapGroups` + * @since 3.4.0 + */ + def flatMapSortedGroups[U: Encoder](sortExprs: Column*)( + f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] + + /** + * (Java-specific) Applies the given function to each group of data. For each unique group, the + * function will be passed the group key and a sorted iterator that contains all of the elements + * in the group. The function can return an iterator containing elements of an arbitrary type + * which will be returned as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * `org.apache.spark.sql.expressions#Aggregator`. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the + * memory constraints of their cluster. + * + * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be + * sorted according to the given sort expressions. That sorting does not add computational + * complexity. + * + * @see + * `org.apache.spark.sql.api.KeyValueGroupedDataset#flatMapGroups` + * @since 3.4.0 + */ + def flatMapSortedGroups[U]( + SortExprs: Array[Column], + f: FlatMapGroupsFunction[K, V, U], + encoder: Encoder[U]): Dataset[U] = { + import org.apache.spark.util.ArrayImplicits._ + flatMapSortedGroups(SortExprs.toImmutableArraySeq: _*)(ToScalaUDF(f))(encoder) + } + + /** + * (Scala-specific) Applies the given function to each group of data. For each unique group, the + * function will be passed the group key and an iterator that contains all of the elements in + * the group. The function can return an element of arbitrary type which will be returned as a + * new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * `org.apache.spark.sql.expressions#Aggregator`. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the + * memory constraints of their cluster. + * + * @since 1.6.0 + */ + def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { + flatMapGroups(UDFAdaptors.mapGroupsToFlatMapGroups(f)) + } + + /** + * (Java-specific) Applies the given function to each group of data. For each unique group, the + * function will be passed the group key and an iterator that contains all of the elements in + * the group. The function can return an element of arbitrary type which will be returned as a + * new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. 
If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * `org.apache.spark.sql.expressions#Aggregator`. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the + * memory constraints of their cluster. + * + * @since 1.6.0 + */ + def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + mapGroups(ToScalaUDF(f))(encoder) + } + + /** + * (Scala-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See + * [[org.apache.spark.sql.streaming.GroupState]] for more details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + * @since 2.2.0 + */ + def mapGroupsWithState[S: Encoder, U: Encoder]( + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] + + /** + * (Scala-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See + * [[org.apache.spark.sql.streaming.GroupState]] for more details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param timeoutConf + * Timeout configuration for groups that do not receive data for a while. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + * @since 2.2.0 + */ + def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)( + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] + + /** + * (Scala-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See + * [[org.apache.spark.sql.streaming.GroupState]] for more details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. 
+ * @param func + * Function to be called on every group. + * @param timeoutConf + * Timeout Conf, see GroupStateTimeout for more details + * @param initialState + * The user provided state that will be initialized when the first batch of data is processed + * in the streaming query. The user defined function will be called on the state data even if + * there are no other values in the group. To convert a Dataset ds of type Dataset[(K, S)] to + * a KeyValueGroupedDataset[K, S] do {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + * @since 3.2.0 + */ + def mapGroupsWithState[S: Encoder, U: Encoder]( + timeoutConf: GroupStateTimeout, + initialState: KVDS[K, S])(func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] + + /** + * (Java-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See `GroupState` for more + * details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param stateEncoder + * Encoder for the state type. + * @param outputEncoder + * Encoder for the output type. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + * @since 2.2.0 + */ + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = { + mapGroupsWithState[S, U](ToScalaUDF(func))(stateEncoder, outputEncoder) + } + + /** + * (Java-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See `GroupState` for more + * details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param stateEncoder + * Encoder for the state type. + * @param outputEncoder + * Encoder for the output type. + * @param timeoutConf + * Timeout configuration for groups that do not receive data for a while. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + * @since 2.2.0 + */ + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout): Dataset[U] = { + mapGroupsWithState[S, U](timeoutConf)(ToScalaUDF(func))(stateEncoder, outputEncoder) + } + + /** + * (Java-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. 
The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See `GroupState` for more + * details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param stateEncoder + * Encoder for the state type. + * @param outputEncoder + * Encoder for the output type. + * @param timeoutConf + * Timeout configuration for groups that do not receive data for a while. + * @param initialState + * The user provided state that will be initialized when the first batch of data is processed + * in the streaming query. The user defined function will be called on the state data even if + * there are no other values in the group. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + * @since 3.2.0 + */ + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout, + initialState: KVDS[K, S]): Dataset[U] = { + val f = ToScalaUDF(func) + mapGroupsWithState[S, U](timeoutConf, initialState)(f)(stateEncoder, outputEncoder) + } + + /** + * (Scala-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See `GroupState` for more + * details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param outputMode + * The output mode of the function. + * @param timeoutConf + * Timeout configuration for groups that do not receive data for a while. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + * @since 2.2.0 + */ + def flatMapGroupsWithState[S: Encoder, U: Encoder]( + outputMode: OutputMode, + timeoutConf: GroupStateTimeout)( + func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] + + /** + * (Scala-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See `GroupState` for more + * details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param outputMode + * The output mode of the function. 
+ * @param timeoutConf + * Timeout configuration for groups that do not receive data for a while. + * @param initialState + * The user provided state that will be initialized when the first batch of data is processed + * in the streaming query. The user defined function will be called on the state data even if + * there are no other values in the group. To covert a Dataset `ds` of type of type + * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use + * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} See [[org.apache.spark.sql.Encoder]] for + * more details on what types are encodable to Spark SQL. + * @since 3.2.0 + */ + def flatMapGroupsWithState[S: Encoder, U: Encoder]( + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + initialState: KVDS[K, S])(func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] + + /** + * (Java-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See `GroupState` for more + * details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param outputMode + * The output mode of the function. + * @param stateEncoder + * Encoder for the state type. + * @param outputEncoder + * Encoder for the output type. + * @param timeoutConf + * Timeout configuration for groups that do not receive data for a while. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + * @since 2.2.0 + */ + def flatMapGroupsWithState[S, U]( + func: FlatMapGroupsWithStateFunction[K, V, S, U], + outputMode: OutputMode, + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout): Dataset[U] = { + val f = ToScalaUDF(func) + flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder) + } + + /** + * (Java-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See `GroupState` for more + * details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param outputMode + * The output mode of the function. + * @param stateEncoder + * Encoder for the state type. + * @param outputEncoder + * Encoder for the output type. + * @param timeoutConf + * Timeout configuration for groups that do not receive data for a while. + * @param initialState + * The user provided state that will be initialized when the first batch of data is processed + * in the streaming query. 
The user defined function will be called on the state data even if + * there are no other values in the group. To covert a Dataset `ds` of type of type + * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use + * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + * @since 3.2.0 + */ + def flatMapGroupsWithState[S, U]( + func: FlatMapGroupsWithStateFunction[K, V, S, U], + outputMode: OutputMode, + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout, + initialState: KVDS[K, S]): Dataset[U] = { + flatMapGroupsWithState[S, U](outputMode, timeoutConf, initialState)(ToScalaUDF(func))( + stateEncoder, + outputEncoder) + } + + /** + * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state + * API v2. We allow the user to act on per-group set of input rows along with keyed state and + * the user can choose to output/return 0 or more rows. For a streaming dataframe, we will + * repeatedly invoke the interface methods for new rows in each trigger and the user's + * state/state variables will be stored persistently across invocations. + * + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param statefulProcessor + * Instance of statefulProcessor whose functions will be invoked by the operator. + * @param timeMode + * The time mode semantics of the stateful processor for timers and TTL. + * @param outputMode + * The output mode of the stateful processor. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + */ + private[sql] def transformWithState[U: Encoder]( + statefulProcessor: StatefulProcessor[K, V, U], + timeMode: TimeMode, + outputMode: OutputMode): Dataset[U] + + /** + * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state + * API v2. We allow the user to act on per-group set of input rows along with keyed state and + * the user can choose to output/return 0 or more rows. For a streaming dataframe, we will + * repeatedly invoke the interface methods for new rows in each trigger and the user's + * state/state variables will be stored persistently across invocations. + * + * Downstream operators would use specified eventTimeColumnName to calculate watermark. Note + * that TimeMode is set to EventTime to ensure correct flow of watermark. + * + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param statefulProcessor + * Instance of statefulProcessor whose functions will be invoked by the operator. + * @param eventTimeColumnName + * eventTime column in the output dataset. Any operations after transformWithState will use + * the new eventTimeColumn. The user needs to ensure that the eventTime for emitted output + * adheres to the watermark boundary, otherwise streaming query will fail. + * @param outputMode + * The output mode of the stateful processor. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + */ + private[sql] def transformWithState[U: Encoder]( + statefulProcessor: StatefulProcessor[K, V, U], + eventTimeColumnName: String, + outputMode: OutputMode): Dataset[U] + + /** + * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API + * v2. 
We allow the user to act on per-group set of input rows along with keyed state and the + * user can choose to output/return 0 or more rows. For a streaming dataframe, we will + * repeatedly invoke the interface methods for new rows in each trigger and the user's + * state/state variables will be stored persistently across invocations. + * + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param statefulProcessor + * Instance of statefulProcessor whose functions will be invoked by the operator. + * @param timeMode + * The time mode semantics of the stateful processor for timers and TTL. + * @param outputMode + * The output mode of the stateful processor. + * @param outputEncoder + * Encoder for the output type. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + */ + private[sql] def transformWithState[U: Encoder]( + statefulProcessor: StatefulProcessor[K, V, U], + timeMode: TimeMode, + outputMode: OutputMode, + outputEncoder: Encoder[U]): Dataset[U] = { + transformWithState(statefulProcessor, timeMode, outputMode)(outputEncoder) + } + + /** + * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API + * v2. We allow the user to act on per-group set of input rows along with keyed state and the + * user can choose to output/return 0 or more rows. + * + * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows in + * each trigger and the user's state/state variables will be stored persistently across + * invocations. + * + * Downstream operators would use specified eventTimeColumnName to calculate watermark. Note + * that TimeMode is set to EventTime to ensure correct flow of watermark. + * + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param statefulProcessor + * Instance of statefulProcessor whose functions will be invoked by the operator. + * @param eventTimeColumnName + * eventTime column in the output dataset. Any operations after transformWithState will use + * the new eventTimeColumn. The user needs to ensure that the eventTime for emitted output + * adheres to the watermark boundary, otherwise streaming query will fail. + * @param outputMode + * The output mode of the stateful processor. + * @param outputEncoder + * Encoder for the output type. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + */ + private[sql] def transformWithState[U: Encoder]( + statefulProcessor: StatefulProcessor[K, V, U], + eventTimeColumnName: String, + outputMode: OutputMode, + outputEncoder: Encoder[U]): Dataset[U] = { + transformWithState(statefulProcessor, eventTimeColumnName, outputMode)(outputEncoder) + } + + /** + * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state + * API v2. Functions as the function above, but with additional initial state. + * + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @tparam S + * The type of initial state objects. Must be encodable to Spark SQL types. + * @param statefulProcessor + * Instance of statefulProcessor whose functions will be invoked by the operator. + * @param timeMode + * The time mode semantics of the stateful processor for timers and TTL. + * @param outputMode + * The output mode of the stateful processor. 
+ * @param initialState + * User provided initial state that will be used to initiate state for the query in the first + * batch. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + */ + private[sql] def transformWithState[U: Encoder, S: Encoder]( + statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + timeMode: TimeMode, + outputMode: OutputMode, + initialState: KVDS[K, S]): Dataset[U] + + /** + * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state + * API v2. Functions as the function above, but with additional eventTimeColumnName for output. + * + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @tparam S + * The type of initial state objects. Must be encodable to Spark SQL types. + * + * Downstream operators would use specified eventTimeColumnName to calculate watermark. Note + * that TimeMode is set to EventTime to ensure correct flow of watermark. + * + * @param statefulProcessor + * Instance of statefulProcessor whose functions will be invoked by the operator. + * @param eventTimeColumnName + * eventTime column in the output dataset. Any operations after transformWithState will use + * the new eventTimeColumn. The user needs to ensure that the eventTime for emitted output + * adheres to the watermark boundary, otherwise streaming query will fail. + * @param outputMode + * The output mode of the stateful processor. + * @param initialState + * User provided initial state that will be used to initiate state for the query in the first + * batch. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + */ + private[sql] def transformWithState[U: Encoder, S: Encoder]( + statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + eventTimeColumnName: String, + outputMode: OutputMode, + initialState: KVDS[K, S]): Dataset[U] + + /** + * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API + * v2. Functions as the function above, but with additional initialStateEncoder for state + * encoding. + * + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @tparam S + * The type of initial state objects. Must be encodable to Spark SQL types. + * @param statefulProcessor + * Instance of statefulProcessor whose functions will be invoked by the operator. + * @param timeMode + * The time mode semantics of the stateful processor for timers and TTL. + * @param outputMode + * The output mode of the stateful processor. + * @param initialState + * User provided initial state that will be used to initiate state for the query in the first + * batch. + * @param outputEncoder + * Encoder for the output type. + * @param initialStateEncoder + * Encoder for the initial state type. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + */ + private[sql] def transformWithState[U: Encoder, S: Encoder]( + statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + timeMode: TimeMode, + outputMode: OutputMode, + initialState: KVDS[K, S], + outputEncoder: Encoder[U], + initialStateEncoder: Encoder[S]): Dataset[U] = { + transformWithState(statefulProcessor, timeMode, outputMode, initialState)( + outputEncoder, + initialStateEncoder) + } + + /** + * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API + * v2. 
Functions as the function above, but with additional eventTimeColumnName for output. + * + * Downstream operators would use specified eventTimeColumnName to calculate watermark. Note + * that TimeMode is set to EventTime to ensure correct flow of watermark. + * + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @tparam S + * The type of initial state objects. Must be encodable to Spark SQL types. + * @param statefulProcessor + * Instance of statefulProcessor whose functions will be invoked by the operator. + * @param outputMode + * The output mode of the stateful processor. + * @param initialState + * User provided initial state that will be used to initiate state for the query in the first + * batch. + * @param eventTimeColumnName + * event column in the output dataset. Any operations after transformWithState will use the + * new eventTimeColumn. The user needs to ensure that the eventTime for emitted output adheres + * to the watermark boundary, otherwise streaming query will fail. + * @param outputEncoder + * Encoder for the output type. + * @param initialStateEncoder + * Encoder for the initial state type. + * + * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark + * SQL. + */ + private[sql] def transformWithState[U: Encoder, S: Encoder]( + statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + outputMode: OutputMode, + initialState: KVDS[K, S], + eventTimeColumnName: String, + outputEncoder: Encoder[U], + initialStateEncoder: Encoder[S]): Dataset[U] = { + transformWithState(statefulProcessor, eventTimeColumnName, outputMode, initialState)( + outputEncoder, + initialStateEncoder) + } + + /** + * (Scala-specific) Reduces the elements of each group of data using the specified binary + * function. The given function must be commutative and associative or the result may be + * non-deterministic. + * + * @since 1.6.0 + */ + def reduceGroups(f: (V, V) => V): Dataset[(K, V)] + + /** + * (Java-specific) Reduces the elements of each group of data using the specified binary + * function. The given function must be commutative and associative or the result may be + * non-deterministic. + * + * @since 1.6.0 + */ + def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = { + reduceGroups(ToScalaUDF(f)) + } + + /** + * Internal helper function for building typed aggregations that return tuples. For simplicity + * and code reuse, we do this without the help of the type system and then use helper functions + * that cast appropriately for the user facing interface. + */ + protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] + + /** + * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key and the + * result of computing this aggregation over all elements in the group. + * + * @since 1.6.0 + */ + def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = + aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and + * the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 + */ + def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = + aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and + * the result of computing these aggregations over all elements in the group. 
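+   *
+   * A hedged sketch, assuming a hypothetical case class `Emp(dept: String, salary: Double)` and
+   * that `spark.implicits._` and `org.apache.spark.sql.functions._` are in scope:
+   * {{{
+   *   ds.groupByKey(_.dept)
+   *     .agg(count(lit(1)).as[Long], avg($"salary").as[Double], max($"salary").as[Double])
+   *   // => Dataset[(String, Long, Double, Double)]
+   * }}}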
+ * + * @since 1.6.0 + */ + def agg[U1, U2, U3]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = + aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and + * the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 + */ + def agg[U1, U2, U3, U4]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = + aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and + * the result of computing these aggregations over all elements in the group. + * + * @since 3.0.0 + */ + def agg[U1, U2, U3, U4, U5]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] = + aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and + * the result of computing these aggregations over all elements in the group. + * + * @since 3.0.0 + */ + def agg[U1, U2, U3, U4, U5, U6]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] = + aggUntyped(col1, col2, col3, col4, col5, col6) + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and + * the result of computing these aggregations over all elements in the group. + * + * @since 3.0.0 + */ + def agg[U1, U2, U3, U4, U5, U6, U7]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6], + col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] = + aggUntyped(col1, col2, col3, col4, col5, col6, col7) + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and + * the result of computing these aggregations over all elements in the group. + * + * @since 3.0.0 + */ + def agg[U1, U2, U3, U4, U5, U6, U7, U8]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6], + col7: TypedColumn[V, U7], + col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] = + aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8) + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]] + + /** + * Returns a [[Dataset]] that contains a tuple with each key and the number of items present for + * that key. + * + * @since 1.6.0 + */ + def count(): Dataset[(K, Long)] = agg(cnt(lit(1)).as(PrimitiveLongEncoder)) + + /** + * (Scala-specific) Applies the given function to each cogrouped data. For each unique group, + * the function will be passed the grouping key and 2 iterators containing all elements in the + * group from [[Dataset]] `this` and `other`. 
The function can return an iterator containing + * elements of an arbitrary type which will be returned as a new [[Dataset]]. + * + * @since 1.6.0 + */ + def cogroup[U, R: Encoder](other: KVDS[K, U])( + f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { + cogroupSorted(other)(Nil: _*)(Nil: _*)(f) + } + + /** + * (Java-specific) Applies the given function to each cogrouped data. For each unique group, the + * function will be passed the grouping key and 2 iterators containing all elements in the group + * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements + * of an arbitrary type which will be returned as a new [[Dataset]]. + * + * @since 1.6.0 + */ + def cogroup[U, R]( + other: KVDS[K, U], + f: CoGroupFunction[K, V, U, R], + encoder: Encoder[R]): Dataset[R] = { + cogroup(other)(ToScalaUDF(f))(encoder) + } + + /** + * (Scala-specific) Applies the given function to each sorted cogrouped data. For each unique + * group, the function will be passed the grouping key and 2 sorted iterators containing all + * elements in the group from [[Dataset]] `this` and `other`. The function can return an + * iterator containing elements of an arbitrary type which will be returned as a new + * [[Dataset]]. + * + * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be + * sorted according to the given sort expressions. That sorting does not add computational + * complexity. + * + * @see + * `org.apache.spark.sql.api.KeyValueGroupedDataset#cogroup` + * @since 3.4.0 + */ + def cogroupSorted[U, R: Encoder](other: KVDS[K, U])(thisSortExprs: Column*)( + otherSortExprs: Column*)(f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] + + /** + * (Java-specific) Applies the given function to each sorted cogrouped data. For each unique + * group, the function will be passed the grouping key and 2 sorted iterators containing all + * elements in the group from [[Dataset]] `this` and `other`. The function can return an + * iterator containing elements of an arbitrary type which will be returned as a new + * [[Dataset]]. + * + * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be + * sorted according to the given sort expressions. That sorting does not add computational + * complexity. 
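+   *
+   * An illustrative sketch of the Scala-specific variant, assuming hypothetical case classes
+   * `Click(user: String, ts: Long)` and `Purchase(user: String, ts: Long, amount: Double)`:
+   * {{{
+   *   clicks.groupByKey(_.user).cogroupSorted(purchases.groupByKey(_.user))($"ts")($"ts") {
+   *     (user, clickIter, purchaseIter) =>
+   *       // both iterators arrive sorted by ts for this user
+   *       Iterator.single((user, clickIter.size, purchaseIter.map(_.amount).sum))
+   *   }
+   * }}}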
+ * + * @see + * `org.apache.spark.sql.api.KeyValueGroupedDataset#cogroup` + * @since 3.4.0 + */ + def cogroupSorted[U, R]( + other: KVDS[K, U], + thisSortExprs: Array[Column], + otherSortExprs: Array[Column], + f: CoGroupFunction[K, V, U, R], + encoder: Encoder[R]): Dataset[R] = { + import org.apache.spark.util.ArrayImplicits._ + cogroupSorted(other)(thisSortExprs.toImmutableArraySeq: _*)( + otherSortExprs.toImmutableArraySeq: _*)(ToScalaUDF(f))(encoder) + } +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala index 30b2992d43a00..118b8f1ecd488 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala @@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters._ import _root_.java.util import org.apache.spark.annotation.Stable -import org.apache.spark.sql.{functions, Column, Row} +import org.apache.spark.sql.{functions, Column, Encoder, Row} /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -30,19 +30,18 @@ import org.apache.spark.sql.{functions, Column, Row} * The main method is the `agg` function, which has multiple variants. This class also contains * some first-order statistics such as `mean`, `sum` for convenience. * - * @note This class was named `GroupedData` in Spark 1.x. + * @note + * This class was named `GroupedData` in Spark 1.x. * @since 2.0.0 */ @Stable -abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { - type RGD <: RelationalGroupedDataset[DS] - - protected def df: DS[Row] +abstract class RelationalGroupedDataset { + protected def df: Dataset[Row] /** * Create a aggregation based on the grouping column, the grouping type, and the aggregations. */ - protected def toDF(aggCols: Seq[Column]): DS[Row] + protected def toDF(aggCols: Seq[Column]): Dataset[Row] protected def selectNumericColumns(colNames: Seq[String]): Seq[Column] @@ -61,13 +60,21 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { private def aggregateNumericColumns( colNames: Seq[String], - function: Column => Column): DS[Row] = { + function: Column => Column): Dataset[Row] = { toDF(selectNumericColumns(colNames).map(function)) } /** - * (Scala-specific) Compute aggregates by specifying the column names and - * aggregate methods. The resulting `DataFrame` will also contain the grouping columns. + * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions of + * current `RelationalGroupedDataset`. + * + * @since 3.0.0 + */ + def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] + + /** + * (Scala-specific) Compute aggregates by specifying the column names and aggregate methods. The + * resulting `DataFrame` will also contain the grouping columns. * * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. * {{{ @@ -80,12 +87,12 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * * @since 1.3.0 */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DS[Row] = + def agg(aggExpr: (String, String), aggExprs: (String, String)*): Dataset[Row] = toDF((aggExpr +: aggExprs).map(toAggCol)) /** - * (Scala-specific) Compute aggregates by specifying a map from column name to - * aggregate methods. The resulting `DataFrame` will also contain the grouping columns. 
+ * (Scala-specific) Compute aggregates by specifying a map from column name to aggregate + * methods. The resulting `DataFrame` will also contain the grouping columns. * * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. * {{{ @@ -98,11 +105,11 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * * @since 1.3.0 */ - def agg(exprs: Map[String, String]): DS[Row] = toDF(exprs.map(toAggCol).toSeq) + def agg(exprs: Map[String, String]): Dataset[Row] = toDF(exprs.map(toAggCol).toSeq) /** - * (Java-specific) Compute aggregates by specifying a map from column name to - * aggregate methods. The resulting `DataFrame` will also contain the grouping columns. + * (Java-specific) Compute aggregates by specifying a map from column name to aggregate methods. + * The resulting `DataFrame` will also contain the grouping columns. * * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. * {{{ @@ -113,7 +120,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * * @since 1.3.0 */ - def agg(exprs: util.Map[String, String]): DS[Row] = { + def agg(exprs: util.Map[String, String]): Dataset[Row] = { agg(exprs.asScala.toMap) } @@ -149,91 +156,92 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * @since 1.3.0 */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DS[Row] = toDF(expr +: exprs) + def agg(expr: Column, exprs: Column*): Dataset[Row] = toDF(expr +: exprs) /** - * Count the number of rows for each group. - * The resulting `DataFrame` will also contain the grouping columns. + * Count the number of rows for each group. The resulting `DataFrame` will also contain the + * grouping columns. * * @since 1.3.0 */ - def count(): DS[Row] = toDF(functions.count(functions.lit(1)).as("count") :: Nil) + def count(): Dataset[Row] = toDF(functions.count(functions.lit(1)).as("count") :: Nil) /** - * Compute the average value for each numeric columns for each group. This is an alias for `avg`. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the average values for them. + * Compute the average value for each numeric columns for each group. This is an alias for + * `avg`. The resulting `DataFrame` will also contain the grouping columns. When specified + * columns are given, only compute the average values for them. * * @since 1.3.0 */ @scala.annotation.varargs - def mean(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.avg) - + def mean(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.avg) /** - * Compute the max value for each numeric columns for each group. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the max values for them. + * Compute the max value for each numeric columns for each group. The resulting `DataFrame` will + * also contain the grouping columns. When specified columns are given, only compute the max + * values for them. * * @since 1.3.0 */ @scala.annotation.varargs - def max(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.max) + def max(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.max) /** - * Compute the mean value for each numeric columns for each group. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the mean values for them. 
+ * Compute the mean value for each numeric columns for each group. The resulting `DataFrame` + * will also contain the grouping columns. When specified columns are given, only compute the + * mean values for them. * * @since 1.3.0 */ @scala.annotation.varargs - def avg(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.avg) + def avg(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.avg) /** - * Compute the min value for each numeric column for each group. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the min values for them. + * Compute the min value for each numeric column for each group. The resulting `DataFrame` will + * also contain the grouping columns. When specified columns are given, only compute the min + * values for them. * * @since 1.3.0 */ @scala.annotation.varargs - def min(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.min) + def min(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.min) /** - * Compute the sum for each numeric columns for each group. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the sum for them. + * Compute the sum for each numeric columns for each group. The resulting `DataFrame` will also + * contain the grouping columns. When specified columns are given, only compute the sum for + * them. * * @since 1.3.0 */ @scala.annotation.varargs - def sum(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.sum) + def sum(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.sum) /** * Pivots a column of the current `DataFrame` and performs the specified aggregation. * - * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine - * the resulting schema of the transformation. To avoid any eager computations, provide an - * explicit list of values via `pivot(pivotColumn: String, values: Seq[Any])`. + * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine the + * resulting schema of the transformation. To avoid any eager computations, provide an explicit + * list of values via `pivot(pivotColumn: String, values: Seq[Any])`. * * {{{ * // Compute the sum of earnings for each year by course with each course as a separate column * df.groupBy("year").pivot("course").sum("earnings") * }}} * - * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, - * except for the aggregation. - * @param pivotColumn Name of the column to pivot. + * @see + * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the + * aggregation. + * @param pivotColumn + * Name of the column to pivot. * @since 1.6.0 */ - def pivot(pivotColumn: String): RGD = pivot(df.col(pivotColumn)) + def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(df.col(pivotColumn)) /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * There are two versions of pivot function: one that requires the caller to specify the list - * of distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. + * Pivots a column of the current `DataFrame` and performs the specified aggregation. 
There are + * two versions of pivot function: one that requires the caller to specify the list of distinct + * values to pivot on, and one that does not. The latter is more concise but less efficient, + * because Spark needs to first compute the list of distinct values internally. * * {{{ * // Compute the sum of earnings for each year by course with each course as a separate column @@ -252,21 +260,24 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * .agg(sum($"earnings")) * }}} * - * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, - * except for the aggregation. - * @param pivotColumn Name of the column to pivot. - * @param values List of values that will be translated to columns in the output DataFrame. + * @see + * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the + * aggregation. + * @param pivotColumn + * Name of the column to pivot. + * @param values + * List of values that will be translated to columns in the output DataFrame. * @since 1.6.0 */ - def pivot(pivotColumn: String, values: Seq[Any]): RGD = + def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = pivot(df.col(pivotColumn), values) /** * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified * aggregation. * - * There are two versions of pivot function: one that requires the caller to specify the list - * of distinct values to pivot on, and one that does not. The latter is more concise but less + * There are two versions of pivot function: one that requires the caller to specify the list of + * distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. * * {{{ @@ -277,62 +288,73 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * df.groupBy("year").pivot("course").sum("earnings"); * }}} * - * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, - * except for the aggregation. - * @param pivotColumn Name of the column to pivot. - * @param values List of values that will be translated to columns in the output DataFrame. + * @see + * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the + * aggregation. + * @param pivotColumn + * Name of the column to pivot. + * @param values + * List of values that will be translated to columns in the output DataFrame. * @since 1.6.0 */ - def pivot(pivotColumn: String, values: util.List[Any]): RGD = + def pivot(pivotColumn: String, values: util.List[Any]): RelationalGroupedDataset = pivot(df.col(pivotColumn), values) /** * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified - * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of - * the `String` type. - * - * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, - * except for the aggregation. - * @param pivotColumn the column to pivot. - * @param values List of values that will be translated to columns in the output DataFrame. + * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of the + * `String` type. + * + * @see + * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the + * aggregation. + * @param pivotColumn + * the column to pivot. + * @param values + * List of values that will be translated to columns in the output DataFrame. 
* @since 2.4.0 */ - def pivot(pivotColumn: Column, values: util.List[Any]): RGD = + def pivot(pivotColumn: Column, values: util.List[Any]): RelationalGroupedDataset = pivot(pivotColumn, values.asScala.toSeq) /** * Pivots a column of the current `DataFrame` and performs the specified aggregation. * - * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine - * the resulting schema of the transformation. To avoid any eager computations, provide an - * explicit list of values via `pivot(pivotColumn: Column, values: Seq[Any])`. + * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine the + * resulting schema of the transformation. To avoid any eager computations, provide an explicit + * list of values via `pivot(pivotColumn: Column, values: Seq[Any])`. * * {{{ * // Compute the sum of earnings for each year by course with each course as a separate column * df.groupBy($"year").pivot($"course").sum($"earnings"); * }}} * - * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, - * except for the aggregation. - * @param pivotColumn he column to pivot. + * @see + * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the + * aggregation. + * @param pivotColumn + * he column to pivot. * @since 2.4.0 */ - def pivot(pivotColumn: Column): RGD + def pivot(pivotColumn: Column): RelationalGroupedDataset /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type. + * Pivots a column of the current `DataFrame` and performs the specified aggregation. This is an + * overloaded version of the `pivot` method with `pivotColumn` of the `String` type. * * {{{ * // Compute the sum of earnings for each year by course with each course as a separate column * df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings") * }}} * - * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, - * except for the aggregation. - * @param pivotColumn the column to pivot. - * @param values List of values that will be translated to columns in the output DataFrame. + * @see + * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the + * aggregation. + * @param pivotColumn + * the column to pivot. + * @param values + * List of values that will be translated to columns in the output DataFrame. * @since 2.4.0 */ - def pivot(pivotColumn: Column, values: Seq[Any]): RGD + def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala index d156aba934b68..41d16b16ab1c5 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala @@ -26,14 +26,14 @@ import _root_.java.net.URI import _root_.java.util import org.apache.spark.annotation.{DeveloperApi, Experimental} -import org.apache.spark.sql.{Encoder, Row} +import org.apache.spark.sql.{Encoder, Row, RuntimeConfig} import org.apache.spark.sql.types.StructType /** * The entry point to programming Spark with the Dataset and DataFrame API. * - * In environments that this has been created upfront (e.g. REPL, notebooks), use the builder - * to get an existing session: + * In environments that this has been created upfront (e.g. 
REPL, notebooks), use the builder to + * get an existing session: * * {{{ * SparkSession.builder().getOrCreate() @@ -49,7 +49,8 @@ import org.apache.spark.sql.types.StructType * .getOrCreate() * }}} */ -abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with Closeable { +abstract class SparkSession extends Serializable with Closeable { + /** * The version of Spark on which this application is running. * @@ -57,6 +58,17 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C */ def version: String + /** + * Runtime configuration interface for Spark. + * + * This is the interface through which the user can get and set all Spark and Hadoop + * configurations that are relevant to Spark SQL. When getting the value of a config, this + * defaults to the value set in the underlying `SparkContext`, if any. + * + * @since 2.0.0 + */ + def conf: RuntimeConfig + /** * A collection of methods for registering user-defined functions (UDF). * @@ -72,24 +84,26 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * DataTypes.StringType); * }}} * - * @note The user-defined functions must be deterministic. Due to optimization, - * duplicate invocations may be eliminated or the function may even be invoked more times - * than it is present in the query. + * @note + * The user-defined functions must be deterministic. Due to optimization, duplicate + * invocations may be eliminated or the function may even be invoked more times than it is + * present in the query. * @since 2.0.0 */ def udf: UDFRegistration /** - * Start a new session with isolated SQL configurations, temporary tables, registered - * functions are isolated, but sharing the underlying `SparkContext` and cached data. - * - * @note Other than the `SparkContext`, all shared state is initialized lazily. - * This method will force the initialization of the shared state to ensure that parent - * and child sessions are set up with the same shared state. If the underlying catalog - * implementation is Hive, this will initialize the metastore, which may take some time. + * Start a new session with isolated SQL configurations, temporary tables, registered functions + * are isolated, but sharing the underlying `SparkContext` and cached data. + * + * @note + * Other than the `SparkContext`, all shared state is initialized lazily. This method will + * force the initialization of the shared state to ensure that parent and child sessions are + * set up with the same shared state. If the underlying catalog implementation is Hive, this + * will initialize the metastore, which may take some time. * @since 2.0.0 */ - def newSession(): SparkSession[DS] + def newSession(): SparkSession /* --------------------------------- * | Methods for creating DataFrames | @@ -101,36 +115,35 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * @since 2.0.0 */ @transient - def emptyDataFrame: DS[Row] + def emptyDataFrame: Dataset[Row] /** * Creates a `DataFrame` from a local Seq of Product. * * @since 2.0.0 */ - def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DS[Row] + def createDataFrame[A <: Product: TypeTag](data: Seq[A]): Dataset[Row] /** - * :: DeveloperApi :: - * Creates a `DataFrame` from a `java.util.List` containing [[org.apache.spark.sql.Row]]s using - * the given schema.It is important to make sure that the structure of every - * [[org.apache.spark.sql.Row]] of the provided List matches the provided schema. 
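As a minimal sketch of the `conf` handle documented in this hunk (assuming an already created session named `spark`):

{{{
// Get and set SQL configuration scoped to this session.
spark.conf.set("spark.sql.shuffle.partitions", "64")
val n = spark.conf.get("spark.sql.shuffle.partitions")  // "64"

// newSession() shares the SparkContext but keeps its own, isolated configuration,
// so the value set above is not visible there.
val other = spark.newSession()
}}}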
Otherwise, - * there will be runtime exception. + * :: DeveloperApi :: Creates a `DataFrame` from a `java.util.List` containing + * [[org.apache.spark.sql.Row]]s using the given schema.It is important to make sure that the + * structure of every [[org.apache.spark.sql.Row]] of the provided List matches the provided + * schema. Otherwise, there will be runtime exception. * * @since 2.0.0 */ @DeveloperApi - def createDataFrame(rows: util.List[Row], schema: StructType): DS[Row] + def createDataFrame(rows: util.List[Row], schema: StructType): Dataset[Row] /** * Applies a schema to a List of Java Beans. * - * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, - * SELECT * queries will return the columns in an undefined order. + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, SELECT * queries + * will return the columns in an undefined order. * * @since 1.6.0 */ - def createDataFrame(data: util.List[_], beanClass: Class[_]): DS[Row] + def createDataFrame(data: util.List[_], beanClass: Class[_]): Dataset[Row] /* ------------------------------- * | Methods for creating DataSets | @@ -141,15 +154,15 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * * @since 2.0.0 */ - def emptyDataset[T: Encoder]: DS[T] + def emptyDataset[T: Encoder]: Dataset[T] /** * Creates a [[Dataset]] from a local Seq of data of a given type. This method requires an - * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL representation) - * that is generally created automatically through implicits from a `SparkSession`, or can be - * created explicitly by calling static methods on `Encoders`. + * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL + * representation) that is generally created automatically through implicits from a + * `SparkSession`, or can be created explicitly by calling static methods on `Encoders`. * - * == Example == + * ==Example== * * {{{ * @@ -170,15 +183,15 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * * @since 2.0.0 */ - def createDataset[T: Encoder](data: Seq[T]): DS[T] + def createDataset[T: Encoder](data: Seq[T]): Dataset[T] /** * Creates a [[Dataset]] from a `java.util.List` of a given type. This method requires an - * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL representation) - * that is generally created automatically through implicits from a `SparkSession`, or can be - * created explicitly by calling static methods on `Encoders`. + * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL + * representation) that is generally created automatically through implicits from a + * `SparkSession`, or can be created explicitly by calling static methods on `Encoders`. * - * == Java Example == + * ==Java Example== * * {{{ * List data = Arrays.asList("hello", "world"); @@ -187,131 +200,134 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * * @since 2.0.0 */ - def createDataset[T: Encoder](data: util.List[T]): DS[T] + def createDataset[T: Encoder](data: util.List[T]): Dataset[T] /** - * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements - * in a range from 0 to `end` (exclusive) with step value 1. + * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a + * range from 0 to `end` (exclusive) with step value 1. 
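To make the schema-matching requirement above concrete, here is a small sketch (the column names and the `spark` session are placeholders):

{{{
import java.util.Arrays
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("name", StringType, nullable = true),
  StructField("age", IntegerType, nullable = false)))

// Each Row must line up with the schema, otherwise a runtime exception is raised.
val people = spark.createDataFrame(
  Arrays.asList(Row("Alice", 30), Row("Bob", 25)), schema)
}}}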
* * @since 2.0.0 */ - def range(end: Long): DS[lang.Long] + def range(end: Long): Dataset[lang.Long] /** - * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements - * in a range from `start` to `end` (exclusive) with step value 1. + * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a + * range from `start` to `end` (exclusive) with step value 1. * * @since 2.0.0 */ - def range(start: Long, end: Long): DS[lang.Long] + def range(start: Long, end: Long): Dataset[lang.Long] /** - * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements - * in a range from `start` to `end` (exclusive) with a step value. + * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a + * range from `start` to `end` (exclusive) with a step value. * * @since 2.0.0 */ - def range(start: Long, end: Long, step: Long): DS[lang.Long] + def range(start: Long, end: Long, step: Long): Dataset[lang.Long] /** - * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements - * in a range from `start` to `end` (exclusive) with a step value, with partition number - * specified. + * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a + * range from `start` to `end` (exclusive) with a step value, with partition number specified. * * @since 2.0.0 */ - def range(start: Long, end: Long, step: Long, numPartitions: Int): DS[lang.Long] + def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[lang.Long] /* ------------------------- * | Catalog-related methods | * ------------------------- */ + /** + * Interface through which the user may create, drop, alter or query underlying databases, + * tables, functions etc. + * + * @since 2.0.0 + */ + def catalog: Catalog + /** * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch - * reading and the returned DataFrame is the batch scan query plan of this table. If it's a view, - * the returned DataFrame is simply the query plan of the view, which can either be a batch or - * streaming query plan. - * - * @param tableName is either a qualified or unqualified name that designates a table or view. - * If a database is specified, it identifies the table/view from the database. - * Otherwise, it first attempts to find a temporary view with the given name - * and then match the table/view from the current database. - * Note that, the global temporary view database is also valid here. + * reading and the returned DataFrame is the batch scan query plan of this table. If it's a + * view, the returned DataFrame is simply the query plan of the view, which can either be a + * batch or streaming query plan. + * + * @param tableName + * is either a qualified or unqualified name that designates a table or view. If a database is + * specified, it identifies the table/view from the database. Otherwise, it first attempts to + * find a temporary view with the given name and then match the table/view from the current + * database. Note that, the global temporary view database is also valid here. * @since 2.0.0 */ - def table(tableName: String): DS[Row] + def table(tableName: String): Dataset[Row] /* ----------------- * | Everything else | * ----------------- */ /** - * Executes a SQL query substituting positional parameters by the given arguments, - * returning the result as a `DataFrame`. - * This API eagerly runs DDL/DML commands, but not for SELECT queries. 
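A short sketch of the `range`, `catalog` and `table` entry points covered above (the table name `sales` is only an example):

{{{
// A single `id` column holding 0, 2, 4, 6, 8, spread across 4 partitions.
val ids = spark.range(0, 10, 2, 4)

// Catalog metadata, e.g. the tables in the current database.
spark.catalog.listTables().show()

// Batch scan of a table, or the query plan of a view.
val sales = spark.table("sales")
}}}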
- * - * @param sqlText A SQL statement with positional parameters to execute. - * @param args An array of Java/Scala objects that can be converted to - * SQL literal expressions. See - * - * Supported Data Types for supported value types in Scala/Java. - * For example, 1, "Steven", LocalDate.of(2023, 4, 2). - * A value can be also a `Column` of a literal or collection constructor functions - * such as `map()`, `array()`, `struct()`, in that case it is taken as is. + * Executes a SQL query substituting positional parameters by the given arguments, returning the + * result as a `DataFrame`. This API eagerly runs DDL/DML commands, but not for SELECT queries. + * + * @param sqlText + * A SQL statement with positional parameters to execute. + * @param args + * An array of Java/Scala objects that can be converted to SQL literal expressions. See Supported Data + * Types for supported value types in Scala/Java. For example, 1, "Steven", + * LocalDate.of(2023, 4, 2). A value can be also a `Column` of a literal or collection + * constructor functions such as `map()`, `array()`, `struct()`, in that case it is taken as + * is. * @since 3.5.0 */ @Experimental - def sql(sqlText: String, args: Array[_]): DS[Row] - - /** - * Executes a SQL query substituting named parameters by the given arguments, - * returning the result as a `DataFrame`. - * This API eagerly runs DDL/DML commands, but not for SELECT queries. - * - * @param sqlText A SQL statement with named parameters to execute. - * @param args A map of parameter names to Java/Scala objects that can be converted to - * SQL literal expressions. See - * - * Supported Data Types for supported value types in Scala/Java. - * For example, map keys: "rank", "name", "birthdate"; - * map values: 1, "Steven", LocalDate.of(2023, 4, 2). - * Map value can be also a `Column` of a literal or collection constructor - * functions such as `map()`, `array()`, `struct()`, in that case it is taken - * as is. + def sql(sqlText: String, args: Array[_]): Dataset[Row] + + /** + * Executes a SQL query substituting named parameters by the given arguments, returning the + * result as a `DataFrame`. This API eagerly runs DDL/DML commands, but not for SELECT queries. + * + * @param sqlText + * A SQL statement with named parameters to execute. + * @param args + * A map of parameter names to Java/Scala objects that can be converted to SQL literal + * expressions. See + * Supported Data Types for supported value types in Scala/Java. For example, map keys: + * "rank", "name", "birthdate"; map values: 1, "Steven", LocalDate.of(2023, 4, 2). Map value + * can be also a `Column` of a literal or collection constructor functions such as `map()`, + * `array()`, `struct()`, in that case it is taken as is. * @since 3.4.0 */ @Experimental - def sql(sqlText: String, args: Map[String, Any]): DS[Row] - - /** - * Executes a SQL query substituting named parameters by the given arguments, - * returning the result as a `DataFrame`. - * This API eagerly runs DDL/DML commands, but not for SELECT queries. - * - * @param sqlText A SQL statement with named parameters to execute. - * @param args A map of parameter names to Java/Scala objects that can be converted to - * SQL literal expressions. See - * - * Supported Data Types for supported value types in Scala/Java. - * For example, map keys: "rank", "name", "birthdate"; - * map values: 1, "Steven", LocalDate.of(2023, 4, 2). 
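The positional and named parameter overloads of `sql` described above can be used roughly as follows (table and parameter values are illustrative):

{{{
// Positional parameters bind to '?' markers in order.
spark.sql("SELECT * FROM events WHERE year = ? AND name = ?", Array(2023, "Steven"))

// Named parameters bind by name.
spark.sql(
  "SELECT * FROM events WHERE year = :year AND name = :name",
  Map("year" -> 2023, "name" -> "Steven"))
}}}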
- * Map value can be also a `Column` of a literal or collection constructor - * functions such as `map()`, `array()`, `struct()`, in that case it is taken - * as is. + def sql(sqlText: String, args: Map[String, Any]): Dataset[Row] + + /** + * Executes a SQL query substituting named parameters by the given arguments, returning the + * result as a `DataFrame`. This API eagerly runs DDL/DML commands, but not for SELECT queries. + * + * @param sqlText + * A SQL statement with named parameters to execute. + * @param args + * A map of parameter names to Java/Scala objects that can be converted to SQL literal + * expressions. See + * Supported Data Types for supported value types in Scala/Java. For example, map keys: + * "rank", "name", "birthdate"; map values: 1, "Steven", LocalDate.of(2023, 4, 2). Map value + * can be also a `Column` of a literal or collection constructor functions such as `map()`, + * `array()`, `struct()`, in that case it is taken as is. * @since 3.4.0 */ @Experimental - def sql(sqlText: String, args: util.Map[String, Any]): DS[Row] = { + def sql(sqlText: String, args: util.Map[String, Any]): Dataset[Row] = { sql(sqlText, args.asScala.toMap) } /** - * Executes a SQL query using Spark, returning the result as a `DataFrame`. - * This API eagerly runs DDL/DML commands, but not for SELECT queries. + * Executes a SQL query using Spark, returning the result as a `DataFrame`. This API eagerly + * runs DDL/DML commands, but not for SELECT queries. * * @since 2.0.0 */ - def sql(sqlText: String): DS[Row] = sql(sqlText, Map.empty[String, Any]) + def sql(sqlText: String): Dataset[Row] = sql(sqlText, Map.empty[String, Any]) /** * Add a single artifact to the current session. @@ -333,6 +349,24 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C @Experimental def addArtifact(uri: URI): Unit + /** + * Add a single in-memory artifact to the session while preserving the directory structure + * specified by `target` under the session's working directory of that particular file + * extension. + * + * Supported target file extensions are .jar and .class. + * + * ==Example== + * {{{ + * addArtifact(bytesBar, "foo/bar.class") + * addArtifact(bytesFlat, "flat.class") + * // Directory structure of the session's working directory for class files would look like: + * // ${WORKING_DIR_FOR_CLASS_FILES}/flat.class + * // ${WORKING_DIR_FOR_CLASS_FILES}/foo/bar.class + * }}} + * + * @since 4.0.0 + */ @Experimental def addArtifact(bytes: Array[Byte], target: String): Unit @@ -367,6 +401,110 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C @scala.annotation.varargs def addArtifacts(uri: URI*): Unit + /** + * Add a tag to be assigned to all the operations started by this thread in this session. + * + * Often, a unit of execution in an application consists of multiple Spark executions. + * Application programmers can use this method to group all those jobs together and give a group + * tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all + * running executions with this tag. For example: + * {{{ + * // In the main thread: + * spark.addTag("myjobs") + * spark.range(10).map(i => { Thread.sleep(10); i }).collect() + * + * // In a separate thread: + * spark.interruptTag("myjobs") + * }}} + * + * There may be multiple tags present at the same time, so different parts of application may + * use different tags to perform cancellation at different levels of granularity. + * + * @param tag + * The tag to be added. 
Cannot contain ',' (comma) character or be an empty string. + * + * @since 4.0.0 + */ + def addTag(tag: String): Unit + + /** + * Remove a tag previously added to be assigned to all the operations started by this thread in + * this session. Noop if such a tag was not added earlier. + * + * @param tag + * The tag to be removed. Cannot contain ',' (comma) character or be an empty string. + * + * @since 4.0.0 + */ + def removeTag(tag: String): Unit + + /** + * Get the operation tags that are currently set to be assigned to all the operations started by + * this thread in this session. + * + * @since 4.0.0 + */ + def getTags(): Set[String] + + /** + * Clear the current thread's operation tags. + * + * @since 4.0.0 + */ + def clearTags(): Unit + + /** + * Request to interrupt all currently running operations of this session. + * + * @note + * This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return + * Sequence of operation IDs requested to be interrupted. + * + * @since 4.0.0 + */ + def interruptAll(): Seq[String] + + /** + * Request to interrupt all currently running operations of this session with the given job tag. + * + * @note + * This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return + * Sequence of operation IDs requested to be interrupted. + * + * @since 4.0.0 + */ + def interruptTag(tag: String): Seq[String] + + /** + * Request to interrupt an operation of this session, given its operation ID. + * + * @note + * This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return + * The operation ID requested to be interrupted, as a single-element sequence, or an empty + * sequence if the operation is not started by this session. + * + * @since 4.0.0 + */ + def interruptOperation(operationId: String): Seq[String] + + /** + * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a + * `DataFrame`. + * {{{ + * sparkSession.read.parquet("/path/to/file.parquet") + * sparkSession.read.schema(schema).json("/path/to/file.json") + * }}} + * + * @since 2.0.0 + */ + def read: DataFrameReader + /** * Executes some code block and prints to stdout the time taken to execute the block. This is * available in Scala only and is used primarily for interactive testing and debugging. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQuery.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQuery.scala new file mode 100644 index 0000000000000..0aeb3518facd8 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQuery.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
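Building on the `addTag` example above, a sketch of the remaining tag and interrupt calls (the tag name is arbitrary):

{{{
spark.addTag("etl")
assert(spark.getTags().contains("etl"))

// From another thread sharing this session, cancel by tag,
// by operation id, or everything at once.
val interrupted: Seq[String] = spark.interruptTag("etl")
// spark.interruptOperation(operationId)
// spark.interruptAll()

spark.removeTag("etl")
spark.clearTags()
}}}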
+ */ + +package org.apache.spark.sql.api + +import _root_.java.util.UUID +import _root_.java.util.concurrent.TimeoutException + +import org.apache.spark.annotation.Evolving +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus} + +/** + * A handle to a query that is executing continuously in the background as new data arrives. All + * these methods are thread-safe. + * @since 2.0.0 + */ +@Evolving +trait StreamingQuery { + + /** + * Returns the user-specified name of the query, or null if not specified. This name can be + * specified in the `org.apache.spark.sql.streaming.DataStreamWriter` as + * `dataframe.writeStream.queryName("query").start()`. This name, if set, must be unique across + * all active queries. + * + * @since 2.0.0 + */ + def name: String + + /** + * Returns the unique id of this query that persists across restarts from checkpoint data. That + * is, this id is generated when a query is started for the first time, and will be the same + * every time it is restarted from checkpoint data. Also see [[runId]]. + * + * @since 2.1.0 + */ + def id: UUID + + /** + * Returns the unique id of this run of the query. That is, every start/restart of a query will + * generate a unique runId. Therefore, every time a query is restarted from checkpoint, it will + * have the same [[id]] but different [[runId]]s. + */ + def runId: UUID + + /** + * Returns the `SparkSession` associated with `this`. + * + * @since 2.0.0 + */ + def sparkSession: SparkSession + + /** + * Returns `true` if this query is actively running. + * + * @since 2.0.0 + */ + def isActive: Boolean + + /** + * Returns the [[org.apache.spark.sql.streaming.StreamingQueryException]] if the query was + * terminated by an exception. + * + * @since 2.0.0 + */ + def exception: Option[StreamingQueryException] + + /** + * Returns the current status of the query. + * + * @since 2.0.2 + */ + def status: StreamingQueryStatus + + /** + * Returns an array of the most recent [[org.apache.spark.sql.streaming.StreamingQueryProgress]] + * updates for this query. The number of progress updates retained for each stream is configured + * by Spark session configuration `spark.sql.streaming.numRecentProgressUpdates`. + * + * @since 2.1.0 + */ + def recentProgress: Array[StreamingQueryProgress] + + /** + * Returns the most recent [[org.apache.spark.sql.streaming.StreamingQueryProgress]] update of + * this streaming query. + * + * @since 2.1.0 + */ + def lastProgress: StreamingQueryProgress + + /** + * Waits for the termination of `this` query, either by `query.stop()` or by an exception. If + * the query has terminated with an exception, then the exception will be thrown. + * + * If the query has terminated, then all subsequent calls to this method will either return + * immediately (if the query was terminated by `stop()`), or throw the exception immediately (if + * the query has terminated with exception). + * + * @throws org.apache.spark.sql.streaming.StreamingQueryException + * if the query has terminated with an exception. + * + * @since 2.0.0 + */ + @throws[StreamingQueryException] + def awaitTermination(): Unit + + /** + * Waits for the termination of `this` query, either by `query.stop()` or by an exception. If + * the query has terminated with an exception, then the exception will be thrown. Otherwise, it + * returns whether the query has terminated or not within the `timeoutMs` milliseconds. 
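A monitoring-oriented sketch of the query handle described above (the streaming DataFrame `streamingDF` is a placeholder):

{{{
val query = streamingDF.writeStream
  .format("console")
  .queryName("events")
  .start()

println(query.id)       // stable across restarts from the same checkpoint
println(query.runId)    // unique per start/restart
println(query.status)   // current status of the query
query.recentProgress.foreach(p => println(p.json))
}}}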
+ * + * If the query has terminated, then all subsequent calls to this method will either return + * `true` immediately (if the query was terminated by `stop()`), or throw the exception + * immediately (if the query has terminated with exception). + * + * @throws org.apache.spark.sql.streaming.StreamingQueryException + * if the query has terminated with an exception + * + * @since 2.0.0 + */ + @throws[StreamingQueryException] + def awaitTermination(timeoutMs: Long): Boolean + + /** + * Blocks until all available data in the source has been processed and committed to the sink. + * This method is intended for testing. Note that in the case of continually arriving data, this + * method may block forever. Additionally, this method is only guaranteed to block until data + * that has been synchronously appended data to a + * `org.apache.spark.sql.execution.streaming.Source` prior to invocation. (i.e. `getOffset` must + * immediately reflect the addition). + * @since 2.0.0 + */ + def processAllAvailable(): Unit + + /** + * Stops the execution of this query if it is running. This waits until the termination of the + * query execution threads or until a timeout is hit. + * + * By default stop will block indefinitely. You can configure a timeout by the configuration + * `spark.sql.streaming.stopTimeout`. A timeout of 0 (or negative) milliseconds will block + * indefinitely. If a `TimeoutException` is thrown, users can retry stopping the stream. If the + * issue persists, it is advisable to kill the Spark application. + * + * @since 2.0.0 + */ + @throws[TimeoutException] + def stop(): Unit + + /** + * Prints the physical plan to the console for debugging purposes. + * @since 2.0.0 + */ + def explain(): Unit + + /** + * Prints the physical plan to the console for debugging purposes. + * + * @param extended + * whether to do extended explain or not + * @since 2.0.0 + */ + def explain(extended: Boolean): Unit +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala index 4611393f0f7ec..c11e266827ff9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala @@ -50,9 +50,12 @@ abstract class UDFRegistration { * spark.udf.register("stringLit", bar.asNonNullable()) * }}} * - * @param name the name of the UDF. - * @param udf the UDF needs to be registered. - * @return the registered UDF. + * @param name + * the name of the UDF. + * @param udf + * the UDF needs to be registered. + * @return + * the registered UDF. * * @since 2.2.0 */ @@ -117,11 +120,12 @@ abstract class UDFRegistration { | registerJavaUDF(name, ToScalaUDF(f), returnType, $i) |}""".stripMargin) } - */ + */ /** * Registers a deterministic Scala closure of 0 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. * @since 1.3.0 */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { @@ -130,200 +134,964 @@ abstract class UDFRegistration { /** * Registers a deterministic Scala closure of 1 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. 
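The termination-related calls on `StreamingQuery` just above (`awaitTermination`, `processAllAvailable`, `stop`) compose roughly like this; the timeout is illustrative and `query` is the handle from the previous sketch:

{{{
// Block for at most 30 seconds; returns true if the query terminated in time.
if (!query.awaitTermination(30000L)) {
  // Typically in tests: drain whatever data is currently available, then shut down.
  query.processAllAvailable()
  query.stop()
}
}}}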
* @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { + def register[RT: TypeTag, A1: TypeTag]( + name: String, + func: Function1[A1, RT]): UserDefinedFunction = { registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]]) } /** * Registers a deterministic Scala closure of 2 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. * @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]]) + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag]( + name: String, + func: Function2[A1, A2, RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]]) } /** * Registers a deterministic Scala closure of 3 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. * @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]]) + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag]( + name: String, + func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]]) } /** * Registers a deterministic Scala closure of 4 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. * @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]]) + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag]( + name: String, + func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]]) } /** * Registers a deterministic Scala closure of 5 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. 
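For the Scala `register` overloads in this block, a minimal usage sketch (function names and bodies are arbitrary):

{{{
// One- and two-argument deterministic Scala UDFs.
spark.udf.register("plusOne", (x: Int) => x + 1)
spark.udf.register("strLenPlus", (s: String, n: Int) => s.length + n)

spark.sql("SELECT plusOne(5) AS six, strLenPlus('spark', 2) AS seven").show()
}}}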
* @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]]) + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag]( + name: String, + func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]]) } /** * Registers a deterministic Scala closure of 6 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. * @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag]( + name: String, + func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]]) } /** * Registers a deterministic Scala closure of 7 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. * @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag]( + name: String, + func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]]) } /** * Registers a deterministic Scala closure of 8 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. 
* @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag]( + name: String, + func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]]) } /** * Registers a deterministic Scala closure of 9 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. * @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag]( + name: String, + func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]]) } /** * Registers a deterministic Scala closure of 10 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. 
* @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag]( + name: String, + func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]], + implicitly[TypeTag[A10]]) } /** * Registers a deterministic Scala closure of 11 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. * @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag]( + name: String, + func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]], + implicitly[TypeTag[A10]], + implicitly[TypeTag[A11]]) } /** * Registers a deterministic Scala closure of 12 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. 
* @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag]( + name: String, + func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]) + : UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]], + implicitly[TypeTag[A10]], + implicitly[TypeTag[A11]], + implicitly[TypeTag[A12]]) } /** * Registers a deterministic Scala closure of 13 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. * @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag]( + name: String, + func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]) + : UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]], + implicitly[TypeTag[A10]], + implicitly[TypeTag[A11]], + implicitly[TypeTag[A12]], + implicitly[TypeTag[A13]]) } /** * Registers a deterministic Scala closure of 14 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. 
* @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag]( + name: String, + func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]) + : UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]], + implicitly[TypeTag[A10]], + implicitly[TypeTag[A11]], + implicitly[TypeTag[A12]], + implicitly[TypeTag[A13]], + implicitly[TypeTag[A14]]) } /** * Registers a deterministic Scala closure of 15 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. * @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag]( + name: String, + func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]) + : UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]], + implicitly[TypeTag[A10]], + implicitly[TypeTag[A11]], + implicitly[TypeTag[A12]], + implicitly[TypeTag[A13]], + implicitly[TypeTag[A14]], + implicitly[TypeTag[A15]]) } /** * Registers a deterministic Scala closure of 16 arguments as user-defined function (UDF). 
- * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. * @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag, + A16: TypeTag]( + name: String, + func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]) + : UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]], + implicitly[TypeTag[A10]], + implicitly[TypeTag[A11]], + implicitly[TypeTag[A12]], + implicitly[TypeTag[A13]], + implicitly[TypeTag[A14]], + implicitly[TypeTag[A15]], + implicitly[TypeTag[A16]]) } /** * Registers a deterministic Scala closure of 17 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. 
* @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag, + A16: TypeTag, + A17: TypeTag]( + name: String, + func: Function17[ + A1, + A2, + A3, + A4, + A5, + A6, + A7, + A8, + A9, + A10, + A11, + A12, + A13, + A14, + A15, + A16, + A17, + RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]], + implicitly[TypeTag[A10]], + implicitly[TypeTag[A11]], + implicitly[TypeTag[A12]], + implicitly[TypeTag[A13]], + implicitly[TypeTag[A14]], + implicitly[TypeTag[A15]], + implicitly[TypeTag[A16]], + implicitly[TypeTag[A17]]) } /** * Registers a deterministic Scala closure of 18 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. 
* @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag, + A16: TypeTag, + A17: TypeTag, + A18: TypeTag]( + name: String, + func: Function18[ + A1, + A2, + A3, + A4, + A5, + A6, + A7, + A8, + A9, + A10, + A11, + A12, + A13, + A14, + A15, + A16, + A17, + A18, + RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]], + implicitly[TypeTag[A10]], + implicitly[TypeTag[A11]], + implicitly[TypeTag[A12]], + implicitly[TypeTag[A13]], + implicitly[TypeTag[A14]], + implicitly[TypeTag[A15]], + implicitly[TypeTag[A16]], + implicitly[TypeTag[A17]], + implicitly[TypeTag[A18]]) } /** * Registers a deterministic Scala closure of 19 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. 
* @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag, + A16: TypeTag, + A17: TypeTag, + A18: TypeTag, + A19: TypeTag]( + name: String, + func: Function19[ + A1, + A2, + A3, + A4, + A5, + A6, + A7, + A8, + A9, + A10, + A11, + A12, + A13, + A14, + A15, + A16, + A17, + A18, + A19, + RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]], + implicitly[TypeTag[A10]], + implicitly[TypeTag[A11]], + implicitly[TypeTag[A12]], + implicitly[TypeTag[A13]], + implicitly[TypeTag[A14]], + implicitly[TypeTag[A15]], + implicitly[TypeTag[A16]], + implicitly[TypeTag[A17]], + implicitly[TypeTag[A18]], + implicitly[TypeTag[A19]]) } /** * Registers a deterministic Scala closure of 20 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. 
* @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag, + A16: TypeTag, + A17: TypeTag, + A18: TypeTag, + A19: TypeTag, + A20: TypeTag]( + name: String, + func: Function20[ + A1, + A2, + A3, + A4, + A5, + A6, + A7, + A8, + A9, + A10, + A11, + A12, + A13, + A14, + A15, + A16, + A17, + A18, + A19, + A20, + RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]], + implicitly[TypeTag[A10]], + implicitly[TypeTag[A11]], + implicitly[TypeTag[A12]], + implicitly[TypeTag[A13]], + implicitly[TypeTag[A14]], + implicitly[TypeTag[A15]], + implicitly[TypeTag[A16]], + implicitly[TypeTag[A17]], + implicitly[TypeTag[A18]], + implicitly[TypeTag[A19]], + implicitly[TypeTag[A20]]) } /** * Registers a deterministic Scala closure of 21 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. 
* @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]], implicitly[TypeTag[A21]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag, + A16: TypeTag, + A17: TypeTag, + A18: TypeTag, + A19: TypeTag, + A20: TypeTag, + A21: TypeTag]( + name: String, + func: Function21[ + A1, + A2, + A3, + A4, + A5, + A6, + A7, + A8, + A9, + A10, + A11, + A12, + A13, + A14, + A15, + A16, + A17, + A18, + A19, + A20, + A21, + RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]], + implicitly[TypeTag[A10]], + implicitly[TypeTag[A11]], + implicitly[TypeTag[A12]], + implicitly[TypeTag[A13]], + implicitly[TypeTag[A14]], + implicitly[TypeTag[A15]], + implicitly[TypeTag[A16]], + implicitly[TypeTag[A17]], + implicitly[TypeTag[A18]], + implicitly[TypeTag[A19]], + implicitly[TypeTag[A20]], + implicitly[TypeTag[A21]]) } /** * Registers a deterministic Scala closure of 22 arguments as user-defined function (UDF). - * @tparam RT return type of UDF. + * @tparam RT + * return type of UDF. 
* @since 1.3.0 */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { - registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]], implicitly[TypeTag[A21]], implicitly[TypeTag[A22]]) + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag, + A16: TypeTag, + A17: TypeTag, + A18: TypeTag, + A19: TypeTag, + A20: TypeTag, + A21: TypeTag, + A22: TypeTag]( + name: String, + func: Function22[ + A1, + A2, + A3, + A4, + A5, + A6, + A7, + A8, + A9, + A10, + A11, + A12, + A13, + A14, + A15, + A16, + A17, + A18, + A19, + A20, + A21, + A22, + RT]): UserDefinedFunction = { + registerScalaUDF( + name, + func, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]], + implicitly[TypeTag[A10]], + implicitly[TypeTag[A11]], + implicitly[TypeTag[A12]], + implicitly[TypeTag[A13]], + implicitly[TypeTag[A14]], + implicitly[TypeTag[A15]], + implicitly[TypeTag[A16]], + implicitly[TypeTag[A17]], + implicitly[TypeTag[A18]], + implicitly[TypeTag[A19]], + implicitly[TypeTag[A20]], + implicitly[TypeTag[A21]], + implicitly[TypeTag[A22]]) } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -405,7 +1173,10 @@ abstract class UDFRegistration { * Register a deterministic Java UDF9 instance as user-defined function (UDF). * @since 1.3.0 */ - def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + def register( + name: String, + f: UDF9[_, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 9) } @@ -413,7 +1184,10 @@ abstract class UDFRegistration { * Register a deterministic Java UDF10 instance as user-defined function (UDF). * @since 1.3.0 */ - def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + def register( + name: String, + f: UDF10[_, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 10) } @@ -421,7 +1195,10 @@ abstract class UDFRegistration { * Register a deterministic Java UDF11 instance as user-defined function (UDF). 
* @since 1.3.0 */ - def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + def register( + name: String, + f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 11) } @@ -429,7 +1206,10 @@ abstract class UDFRegistration { * Register a deterministic Java UDF12 instance as user-defined function (UDF). * @since 1.3.0 */ - def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + def register( + name: String, + f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 12) } @@ -437,7 +1217,10 @@ abstract class UDFRegistration { * Register a deterministic Java UDF13 instance as user-defined function (UDF). * @since 1.3.0 */ - def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + def register( + name: String, + f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 13) } @@ -445,7 +1228,10 @@ abstract class UDFRegistration { * Register a deterministic Java UDF14 instance as user-defined function (UDF). * @since 1.3.0 */ - def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + def register( + name: String, + f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 14) } @@ -453,7 +1239,10 @@ abstract class UDFRegistration { * Register a deterministic Java UDF15 instance as user-defined function (UDF). * @since 1.3.0 */ - def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + def register( + name: String, + f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 15) } @@ -461,7 +1250,10 @@ abstract class UDFRegistration { * Register a deterministic Java UDF16 instance as user-defined function (UDF). * @since 1.3.0 */ - def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + def register( + name: String, + f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 16) } @@ -469,7 +1261,10 @@ abstract class UDFRegistration { * Register a deterministic Java UDF17 instance as user-defined function (UDF). * @since 1.3.0 */ - def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + def register( + name: String, + f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 17) } @@ -477,7 +1272,10 @@ abstract class UDFRegistration { * Register a deterministic Java UDF18 instance as user-defined function (UDF). 
* @since 1.3.0 */ - def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + def register( + name: String, + f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 18) } @@ -485,7 +1283,10 @@ abstract class UDFRegistration { * Register a deterministic Java UDF19 instance as user-defined function (UDF). * @since 1.3.0 */ - def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + def register( + name: String, + f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 19) } @@ -493,7 +1294,10 @@ abstract class UDFRegistration { * Register a deterministic Java UDF20 instance as user-defined function (UDF). * @since 1.3.0 */ - def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + def register( + name: String, + f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 20) } @@ -501,7 +1305,10 @@ abstract class UDFRegistration { * Register a deterministic Java UDF21 instance as user-defined function (UDF). * @since 1.3.0 */ - def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + def register( + name: String, + f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 21) } @@ -509,7 +1316,10 @@ abstract class UDFRegistration { * Register a deterministic Java UDF22 instance as user-defined function (UDF). 
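The Java-facing overloads take an explicit `returnType` because `UDF1`..`UDF22` instances carry no `TypeTag`. A hedged sketch, reusing the `spark` session assumed above and an illustrative function name:

```scala
import org.apache.spark.sql.api.java.UDF2
import org.apache.spark.sql.types.LongType

// Java-style UDFs are type-erased, so the result DataType is supplied by hand.
spark.udf.register(
  "multiply",
  new UDF2[java.lang.Long, java.lang.Long, java.lang.Long] {
    override def call(a: java.lang.Long, b: java.lang.Long): java.lang.Long = a * b
  },
  LongType)

spark.sql("SELECT multiply(6L, 7L)").show() // 42
```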
* @since 1.3.0 */ - def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + def register( + name: String, + f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { registerJavaUDF(name, ToScalaUDF(f), returnType, 22) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalog/interface.scala similarity index 95% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/interface.scala rename to sql/api/src/main/scala/org/apache/spark/sql/catalog/interface.scala index 33e9007ac7e2b..3a3ba9d261326 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/interface.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalog/interface.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalog import javax.annotation.Nullable +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.DefinedByConstructorParams // Note: all classes here are expected to be wrapped in Datasets and so must extend @@ -31,16 +32,13 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams * name of the catalog * @param description * description of the catalog - * @since 3.5.0 + * @since 3.4.0 */ class CatalogMetadata(val name: String, @Nullable val description: String) extends DefinedByConstructorParams { - override def toString: String = { - "Catalog[" + - s"name='$name', " + - Option(description).map { d => s"description='$d'] " }.getOrElse("]") - } + override def toString: String = + s"Catalog[name='$name', ${Option(description).map(d => s"description='$d'").getOrElse("")}]" } /** @@ -54,8 +52,9 @@ class CatalogMetadata(val name: String, @Nullable val description: String) * description of the database. * @param locationUri * path (in the form of a uri) to data files. - * @since 3.5.0 + * @since 2.0.0 */ +@Stable class Database( val name: String, @Nullable val catalog: String, @@ -92,8 +91,9 @@ class Database( * type of the table (e.g. view, table). * @param isTemporary * whether the table is a temporary table. - * @since 3.5.0 + * @since 2.0.0 */ +@Stable class Table( val name: String, @Nullable val catalog: String, @@ -155,8 +155,9 @@ class Table( * whether the column is a bucket column. * @param isCluster * whether the column is a clustering column. - * @since 3.5.0 + * @since 2.0.0 */ +@Stable class Column( val name: String, @Nullable val description: String, @@ -205,8 +206,9 @@ class Column( * the fully qualified class name of the function. * @param isTemporary * whether the function is a temporary function or not. - * @since 3.5.0 + * @since 2.0.0 */ +@Stable class Function( val name: String, @Nullable val catalog: String, diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/DefinedByConstructorParams.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/DefinedByConstructorParams.scala index fc6bc2095a821..efd9d457ef7d4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/DefinedByConstructorParams.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/DefinedByConstructorParams.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst /** - * A helper trait to create [[org.apache.spark.sql.catalyst.encoders.ExpressionEncoder]]s - * for classes whose fields are entirely defined by constructor params but should not be - * case classes. 
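The `CatalogMetadata`, `Database`, `Table`, `Column`, and `Function` classes relocated above are the row types that the `spark.catalog` API surfaces as Datasets. A brief usage sketch, assuming the same `spark` session:

```scala
// Each call returns a Dataset of the corresponding catalog class shown above.
spark.catalog.listDatabases().show(truncate = false)   // Dataset[Database]
spark.catalog.listTables("default").show()             // Dataset[Table]
spark.catalog.listFunctions().take(5).foreach(println) // Function instances
```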
+ * A helper trait to create [[org.apache.spark.sql.catalyst.encoders.ExpressionEncoder]]s for + * classes whose fields are entirely defined by constructor params but should not be case classes. */ trait DefinedByConstructorParams diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index f85e96da2be11..8d0103ca69635 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -36,10 +36,13 @@ import org.apache.spark.util.ArrayImplicits._ * Type-inference utilities for POJOs and Java collections. */ object JavaTypeInference { + /** * Infers the corresponding SQL data type of a Java type. - * @param beanType Java type - * @return (SQL data type, nullable) + * @param beanType + * Java type + * @return + * (SQL data type, nullable) */ def inferDataType(beanType: Type): (DataType, Boolean) = { val encoder = encoderFor(beanType) @@ -60,8 +63,10 @@ object JavaTypeInference { encoderFor(beanType, Set.empty).asInstanceOf[AgnosticEncoder[T]] } - private def encoderFor(t: Type, seenTypeSet: Set[Class[_]], - typeVariables: Map[TypeVariable[_], Type] = Map.empty): AgnosticEncoder[_] = t match { + private def encoderFor( + t: Type, + seenTypeSet: Set[Class[_]], + typeVariables: Map[TypeVariable[_], Type] = Map.empty): AgnosticEncoder[_] = t match { case c: Class[_] if c == java.lang.Boolean.TYPE => PrimitiveBooleanEncoder case c: Class[_] if c == java.lang.Byte.TYPE => PrimitiveByteEncoder @@ -94,14 +99,22 @@ object JavaTypeInference { case c: Class[_] if c.isEnum => JavaEnumEncoder(ClassTag(c)) case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => - val udt = c.getAnnotation(classOf[SQLUserDefinedType]).udt() - .getConstructor().newInstance().asInstanceOf[UserDefinedType[Any]] + val udt = c + .getAnnotation(classOf[SQLUserDefinedType]) + .udt() + .getConstructor() + .newInstance() + .asInstanceOf[UserDefinedType[Any]] val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt() UDTEncoder(udt, udtClass) case c: Class[_] if UDTRegistration.exists(c.getName) => - val udt = UDTRegistration.getUDTFor(c.getName).get.getConstructor(). - newInstance().asInstanceOf[UserDefinedType[Any]] + val udt = UDTRegistration + .getUDTFor(c.getName) + .get + .getConstructor() + .newInstance() + .asInstanceOf[UserDefinedType[Any]] UDTEncoder(udt, udt.getClass) case c: Class[_] if c.isArray => @@ -160,7 +173,8 @@ object JavaTypeInference { def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { val beanInfo = Introspector.getBeanInfo(beanClass) - beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + beanInfo.getPropertyDescriptors + .filterNot(_.getName == "class") .filterNot(_.getName == "declaringClass") .filter(_.getReadMethod != null) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index f204421b0add2..cd12cbd267cc4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -55,9 +55,8 @@ object ScalaReflection extends ScalaReflection { import scala.collection.Map /** - * Synchronize to prevent concurrent usage of `<:<` operator. 
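`JavaTypeInference.encoderFor` above is what derives encoders for Java beans; the public entry point is `Encoders.bean`. A sketch with a hypothetical bean class:

```scala
import org.apache.spark.sql.Encoders

// Hypothetical bean; JavaTypeInference inspects its getter/setter pairs.
class Person extends Serializable {
  private var name: String = _
  private var age: Int = 0
  def getName: String = name
  def setName(n: String): Unit = { name = n }
  def getAge: Int = age
  def setAge(a: Int): Unit = { age = a }
}

val personEncoder = Encoders.bean(classOf[Person])
println(personEncoder.schema) // StructType with "age" and "name" fields
```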
- * This operator is not thread safe in any current version of scala; i.e. - * (2.11.12, 2.12.10, 2.13.0-M5). + * Synchronize to prevent concurrent usage of `<:<` operator. This operator is not thread safe + * in any current version of scala; i.e. (2.11.12, 2.12.10, 2.13.0-M5). * * See https://github.com/scala/bug/issues/10766 */ @@ -91,11 +90,11 @@ object ScalaReflection extends ScalaReflection { /** * Workaround for [[https://github.com/scala/bug/issues/12190 Scala bug #12190]] * - * `ClassSymbol.selfType` can throw an exception in case of cyclic annotation reference - * in Java classes. A retry of this operation will succeed as the class which defines the - * cycle is now resolved. It can however expose further recursive annotation references, so - * we keep retrying until we exhaust our retry threshold. Default threshold is set to 5 - * to allow for a few level of cyclic references. + * `ClassSymbol.selfType` can throw an exception in case of cyclic annotation reference in Java + * classes. A retry of this operation will succeed as the class which defines the cycle is now + * resolved. It can however expose further recursive annotation references, so we keep retrying + * until we exhaust our retry threshold. Default threshold is set to 5 to allow for a few level + * of cyclic references. */ @tailrec private def selfType(clsSymbol: ClassSymbol, tries: Int = 5): Type = { @@ -107,7 +106,7 @@ object ScalaReflection extends ScalaReflection { // Retry on Symbols#CyclicReference if we haven't exhausted our retry limit selfType(clsSymbol, tries - 1) case Failure(e: RuntimeException) - if e.getMessage.contains("illegal cyclic reference") && tries > 1 => + if e.getMessage.contains("illegal cyclic reference") && tries > 1 => // UnPickler.unpickle wraps the original Symbols#CyclicReference exception into a runtime // exception and does not set the cause, so we inspect the message. The previous case // statement is useful for Java classes while this one is for Scala classes. @@ -131,14 +130,14 @@ object ScalaReflection extends ScalaReflection { } /** - * Returns the full class name for a type. The returned name is the canonical - * Scala name, where each component is separated by a period. It is NOT the - * Java-equivalent runtime name (no dollar signs). + * Returns the full class name for a type. The returned name is the canonical Scala name, where + * each component is separated by a period. It is NOT the Java-equivalent runtime name (no + * dollar signs). * - * In simple cases, both the Scala and Java names are the same, however when Scala - * generates constructs that do not map to a Java equivalent, such as singleton objects - * or nested classes in package objects, it uses the dollar sign ($) to create - * synthetic classes, emulating behaviour in Java bytecode. + * In simple cases, both the Scala and Java names are the same, however when Scala generates + * constructs that do not map to a Java equivalent, such as singleton objects or nested classes + * in package objects, it uses the dollar sign ($) to create synthetic classes, emulating + * behaviour in Java bytecode. */ def getClassNameFromType(tpe: `Type`): String = { erasure(tpe).dealias.typeSymbol.asClass.fullName @@ -152,20 +151,25 @@ object ScalaReflection extends ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + /** + * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. 
+ */ def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + /** + * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. + */ def schemaFor(tpe: `Type`): Schema = { val enc = encoderFor(tpe) Schema(enc.dataType, enc.nullable) } /** - * Finds an accessible constructor with compatible parameters. This is a more flexible search than - * the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible + * Finds an accessible constructor with compatible parameters. This is a more flexible search + * than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible * matching constructor is returned if it exists. Otherwise, we check for additional compatible - * constructors defined in the companion object as `apply` methods. Otherwise, it returns `None`. + * constructors defined in the companion object as `apply` methods. Otherwise, it returns + * `None`. */ def findConstructor[T](cls: Class[T], paramTypes: Seq[Class[_]]): Option[Seq[AnyRef] => T] = { Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) match { @@ -174,24 +178,28 @@ object ScalaReflection extends ScalaReflection { val companion = mirror.staticClass(cls.getName).companion val moduleMirror = mirror.reflectModule(companion.asModule) val applyMethods = companion.asTerm.typeSignature - .member(universe.TermName("apply")).asTerm.alternatives - applyMethods.find { method => - val params = method.typeSignature.paramLists.head - // Check that the needed params are the same length and of matching types - params.size == paramTypes.size && - params.zip(paramTypes).forall { case(ps, pc) => - ps.typeSignature.typeSymbol == mirror.classSymbol(pc) + .member(universe.TermName("apply")) + .asTerm + .alternatives + applyMethods + .find { method => + val params = method.typeSignature.paramLists.head + // Check that the needed params are the same length and of matching types + params.size == paramTypes.size && + params.zip(paramTypes).forall { case (ps, pc) => + ps.typeSignature.typeSymbol == mirror.classSymbol(pc) + } } - }.map { applyMethodSymbol => - val expectedArgsCount = applyMethodSymbol.typeSignature.paramLists.head.size - val instanceMirror = mirror.reflect(moduleMirror.instance) - val method = instanceMirror.reflectMethod(applyMethodSymbol.asMethod) - (_args: Seq[AnyRef]) => { - // Drop the "outer" argument if it is provided - val args = if (_args.size == expectedArgsCount) _args else _args.tail - method.apply(args: _*).asInstanceOf[T] + .map { applyMethodSymbol => + val expectedArgsCount = applyMethodSymbol.typeSignature.paramLists.head.size + val instanceMirror = mirror.reflect(moduleMirror.instance) + val method = instanceMirror.reflectMethod(applyMethodSymbol.asMethod) + (_args: Seq[AnyRef]) => { + // Drop the "outer" argument if it is provided + val args = if (_args.size == expectedArgsCount) _args else _args.tail + method.apply(args: _*).asInstanceOf[T] + } } - } } } @@ -201,8 +209,10 @@ object ScalaReflection extends ScalaReflection { def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects { tpe.dealias match { // `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a struct type. 
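`schemaFor` resolves a Catalyst schema from a Scala `TypeTag`, and `encoderFor` builds the matching `AgnosticEncoder`. A sketch with a hypothetical case class:

```scala
import org.apache.spark.sql.catalyst.ScalaReflection

// Hypothetical case class used only for illustration.
case class Point(x: Double, y: Double, label: Option[String])

val s = ScalaReflection.schemaFor[Point]
println(s.dataType) // e.g. StructType(StructField(x,DoubleType,false), ...)
println(s.nullable) // nullability of the top-level type
```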
- case t if isSubtype(t, localTypeOf[Option[_]]) => definedByConstructorParams(t.typeArgs.head) - case _ => isSubtype(tpe.dealias, localTypeOf[Product]) || + case t if isSubtype(t, localTypeOf[Option[_]]) => + definedByConstructorParams(t.typeArgs.head) + case _ => + isSubtype(tpe.dealias, localTypeOf[Product]) || isSubtype(tpe.dealias, localTypeOf[DefinedByConstructorParams]) } } @@ -214,16 +224,15 @@ object ScalaReflection extends ScalaReflection { /** * Create an [[AgnosticEncoder]] from a [[TypeTag]]. * - * If the given type is not supported, i.e. there is no encoder can be built for this type, - * an [[SparkUnsupportedOperationException]] will be thrown with detailed error message to - * explain the type path walked so far and which class we are not supporting. - * There are 4 kinds of type path: - * * the root type: `root class: "abc.xyz.MyClass"` - * * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"` - * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` - * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` + * If the given type is not supported, i.e. there is no encoder can be built for this type, an + * [[SparkUnsupportedOperationException]] will be thrown with detailed error message to explain + * the type path walked so far and which class we are not supporting. There are 4 kinds of type + * path: * the root type: `root class: "abc.xyz.MyClass"` * the value type of [[Option]]: + * `option value class: "abc.xyz.MyClass"` * the element type of [[Array]] or [[Seq]]: `array + * element class: "abc.xyz.MyClass"` * the field of [[Product]]: `field (class: + * "abc.xyz.MyClass", name: "myField")` */ - def encoderFor[E : TypeTag]: AgnosticEncoder[E] = { + def encoderFor[E: TypeTag]: AgnosticEncoder[E] = { encoderFor(typeTag[E].in(mirror).tpe).asInstanceOf[AgnosticEncoder[E]] } @@ -239,13 +248,12 @@ object ScalaReflection extends ScalaReflection { /** * Create an [[AgnosticEncoder]] for a [[Type]]. */ - def encoderFor( - tpe: `Type`, - isRowEncoderSupported: Boolean = false): AgnosticEncoder[_] = cleanUpReflectionObjects { - val clsName = getClassNameFromType(tpe) - val walkedTypePath = WalkedTypePath().recordRoot(clsName) - encoderFor(tpe, Set.empty, walkedTypePath, isRowEncoderSupported) - } + def encoderFor(tpe: `Type`, isRowEncoderSupported: Boolean = false): AgnosticEncoder[_] = + cleanUpReflectionObjects { + val clsName = getClassNameFromType(tpe) + val walkedTypePath = WalkedTypePath().recordRoot(clsName) + encoderFor(tpe, Set.empty, walkedTypePath, isRowEncoderSupported) + } private def encoderFor( tpe: `Type`, @@ -327,14 +335,22 @@ object ScalaReflection extends ScalaReflection { // UDT encoders case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt(). - getConstructor().newInstance().asInstanceOf[UserDefinedType[Any]] + val udt = getClassFromType(t) + .getAnnotation(classOf[SQLUserDefinedType]) + .udt() + .getConstructor() + .newInstance() + .asInstanceOf[UserDefinedType[Any]] val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt() UDTEncoder(udt, udtClass) case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). 
- newInstance().asInstanceOf[UserDefinedType[Any]] + val udt = UDTRegistration + .getUDTFor(getClassNameFromType(t)) + .get + .getConstructor() + .newInstance() + .asInstanceOf[UserDefinedType[Any]] UDTEncoder(udt, udt.getClass) // Complex encoders @@ -380,20 +396,17 @@ object ScalaReflection extends ScalaReflection { if (seenTypeSet.contains(t)) { throw ExecutionErrors.cannotHaveCircularReferencesInClassError(t.toString) } - val params = getConstructorParameters(t).map { - case (fieldName, fieldType) => - if (SourceVersion.isKeyword(fieldName) || - !SourceVersion.isIdentifier(encodeFieldNameToIdentifier(fieldName))) { - throw ExecutionErrors.cannotUseInvalidJavaIdentifierAsFieldNameError( - fieldName, - path) - } - val encoder = encoderFor( - fieldType, - seenTypeSet + t, - path.recordField(getClassNameFromType(fieldType), fieldName), - isRowEncoderSupported) - EncoderField(fieldName, encoder, encoder.nullable, Metadata.empty) + val params = getConstructorParameters(t).map { case (fieldName, fieldType) => + if (SourceVersion.isKeyword(fieldName) || + !SourceVersion.isIdentifier(encodeFieldNameToIdentifier(fieldName))) { + throw ExecutionErrors.cannotUseInvalidJavaIdentifierAsFieldNameError(fieldName, path) + } + val encoder = encoderFor( + fieldType, + seenTypeSet + t, + path.recordField(getClassNameFromType(fieldType), fieldName), + isRowEncoderSupported) + EncoderField(fieldName, encoder, encoder.nullable, Metadata.empty) } val cls = getClassFromType(t) ProductEncoder(ClassTag(cls), params, Option(OuterScopes.getOuterScope(cls))) @@ -404,10 +417,11 @@ object ScalaReflection extends ScalaReflection { } /** - * Support for generating catalyst schemas for scala objects. Note that unlike its companion + * Support for generating catalyst schemas for scala objects. Note that unlike its companion * object, this trait able to work in both the runtime and the compile time (macro) universe. */ trait ScalaReflection extends Logging { + /** The universe we work in (runtime or macro) */ val universe: scala.reflect.api.Universe @@ -421,7 +435,8 @@ trait ScalaReflection extends Logging { * clean up the Scala reflection garbage automatically. Otherwise, it will leak some objects to * `scala.reflect.runtime.JavaUniverse.undoLog`. * - * @see https://github.com/scala/bug/issues/8302 + * @see + * https://github.com/scala/bug/issues/8302 */ def cleanUpReflectionObjects[T](func: => T): T = { universe.asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.undo(func) @@ -430,12 +445,13 @@ trait ScalaReflection extends Logging { /** * Return the Scala Type for `T` in the current classloader mirror. * - * Use this method instead of the convenience method `universe.typeOf`, which - * assumes that all types can be found in the classloader that loaded scala-reflect classes. - * That's not necessarily the case when running using Eclipse launchers or even - * Sbt console or test (without `fork := true`). + * Use this method instead of the convenience method `universe.typeOf`, which assumes that all + * types can be found in the classloader that loaded scala-reflect classes. That's not + * necessarily the case when running using Eclipse launchers or even Sbt console or test + * (without `fork := true`). * - * @see SPARK-5281 + * @see + * SPARK-5281 */ def localTypeOf[T: TypeTag]: `Type` = { val tag = implicitly[TypeTag[T]] @@ -474,8 +490,8 @@ trait ScalaReflection extends Logging { } /** - * If our type is a Scala trait it may have a companion object that - * only defines a constructor via `apply` method. 
+ * If our type is a Scala trait it may have a companion object that only defines a constructor + * via `apply` method. */ private def getCompanionConstructor(tpe: Type): Symbol = { def throwUnsupportedOperation = { @@ -483,10 +499,11 @@ trait ScalaReflection extends Logging { } tpe.typeSymbol.asClass.companion match { case NoSymbol => throwUnsupportedOperation - case sym => sym.asTerm.typeSignature.member(universe.TermName("apply")) match { - case NoSymbol => throwUnsupportedOperation - case constructorSym => constructorSym - } + case sym => + sym.asTerm.typeSignature.member(universe.TermName("apply")) match { + case NoSymbol => throwUnsupportedOperation + case constructorSym => constructorSym + } } } @@ -499,8 +516,9 @@ trait ScalaReflection extends Logging { constructorSymbol.asMethod.paramLists } else { // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find( - s => s.isMethod && s.asMethod.isPrimaryConstructor) + val primaryConstructorSymbol: Option[Symbol] = + constructorSymbol.asTerm.alternatives.find(s => + s.isMethod && s.asMethod.isPrimaryConstructor) if (primaryConstructorSymbol.isEmpty) { throw ExecutionErrors.primaryConstructorNotFoundError(tpe.getClass) } else { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala index cbf1f01344c92..a81c071295e2c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst /** - * This class records the paths the serializer and deserializer walk through to reach current path. - * Note that this class adds new path in prior to recorded paths so it maintains - * the paths as reverse order. + * This class records the paths the serializer and deserializer walk through to reach current + * path. Note that this class adds new path in prior to recorded paths so it maintains the paths + * as reverse order. */ case class WalkedTypePath(private val walkedPaths: Seq[String] = Nil) { def recordRoot(className: String): WalkedTypePath = @@ -33,7 +33,8 @@ case class WalkedTypePath(private val walkedPaths: Seq[String] = Nil) { newInstance(s"""- array element class: "$elementClassName"""") def recordMap(keyClassName: String, valueClassName: String): WalkedTypePath = { - newInstance(s"""- map key class: "$keyClassName"""" + + newInstance( + s"""- map key class: "$keyClassName"""" + s""", value class: "$valueClassName"""") } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala index 9955f1b7bd301..913881f326c90 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala @@ -20,20 +20,17 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.QuotingUtils.quoted - /** - * Thrown by a catalog when an item already exists. The analyzer will rethrow the exception - * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. + * Thrown by a catalog when an item already exists. 
The analyzer will rethrow the exception as an + * [[org.apache.spark.sql.AnalysisException]] with the correct position information. */ case class NonEmptyNamespaceException( namespace: Array[String], details: String, override val cause: Option[Throwable] = None) - extends AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3103", - messageParameters = Map( - "namespace" -> quoted(namespace), - "details" -> details)) { + extends AnalysisException( + errorClass = "_LEGACY_ERROR_TEMP_3103", + messageParameters = Map("namespace" -> quoted(namespace), "details" -> details)) { def this(namespace: Array[String]) = this(namespace, "", None) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/SqlApiAnalysis.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/SqlApiAnalysis.scala index 9f5a5b8875b33..f218a12209d61 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/SqlApiAnalysis.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/SqlApiAnalysis.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.analysis object SqlApiAnalysis { + /** - * Resolver should return true if the first string refers to the same entity as the second string. - * For example, by using case insensitive equality. + * Resolver should return true if the first string refers to the same entity as the second + * string. For example, by using case insensitive equality. */ type Resolver = (String, String) => Boolean } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala index fae3711baf706..0c667dd8916ee 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala @@ -25,21 +25,21 @@ import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.util.ArrayImplicits._ /** - * Thrown by a catalog when an item already exists. The analyzer will rethrow the exception - * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. + * Thrown by a catalog when an item already exists. The analyzer will rethrow the exception as an + * [[org.apache.spark.sql.AnalysisException]] with the correct position information. 
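These already-exists exceptions reach users as `AnalysisException`s carrying an error class. A hedged sketch of observing one, assuming the `spark` session from earlier and an illustrative schema name:

```scala
import org.apache.spark.sql.AnalysisException

spark.sql("CREATE DATABASE IF NOT EXISTS demo_db")
try {
  spark.sql("CREATE DATABASE demo_db") // second attempt fails: schema already exists
} catch {
  case e: AnalysisException =>
    println(e.getErrorClass) // e.g. SCHEMA_ALREADY_EXISTS
}
```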
*/ class DatabaseAlreadyExistsException(db: String) - extends NamespaceAlreadyExistsException(Array(db)) + extends NamespaceAlreadyExistsException(Array(db)) // any changes to this class should be backward compatible as it may be used by external connectors -class NamespaceAlreadyExistsException private( +class NamespaceAlreadyExistsException private ( message: String, errorClass: Option[String], messageParameters: Map[String, String]) - extends AnalysisException( - message, - errorClass = errorClass, - messageParameters = messageParameters) { + extends AnalysisException( + message, + errorClass = errorClass, + messageParameters = messageParameters) { def this(errorClass: String, messageParameters: Map[String, String]) = { this( @@ -49,24 +49,28 @@ class NamespaceAlreadyExistsException private( } def this(namespace: Array[String]) = { - this(errorClass = "SCHEMA_ALREADY_EXISTS", + this( + errorClass = "SCHEMA_ALREADY_EXISTS", Map("schemaName" -> quoteNameParts(namespace.toImmutableArraySeq))) } } // any changes to this class should be backward compatible as it may be used by external connectors -class TableAlreadyExistsException private( +class TableAlreadyExistsException private ( message: String, cause: Option[Throwable], errorClass: Option[String], messageParameters: Map[String, String]) - extends AnalysisException( - message, - cause = cause, - errorClass = errorClass, - messageParameters = messageParameters) { + extends AnalysisException( + message, + cause = cause, + errorClass = errorClass, + messageParameters = messageParameters) { - def this(errorClass: String, messageParameters: Map[String, String], cause: Option[Throwable]) = { + def this( + errorClass: String, + messageParameters: Map[String, String], + cause: Option[Throwable]) = { this( SparkThrowableHelper.getMessage(errorClass, messageParameters), cause, @@ -75,47 +79,53 @@ class TableAlreadyExistsException private( } def this(db: String, table: String) = { - this(errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", - messageParameters = Map("relationName" -> - (quoteIdentifier(db) + "." + quoteIdentifier(table))), + this( + errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + messageParameters = Map( + "relationName" -> + (quoteIdentifier(db) + "." 
+ quoteIdentifier(table))), cause = None) } def this(table: String) = { - this(errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", - messageParameters = Map("relationName" -> - quoteNameParts(AttributeNameParser.parseAttributeName(table))), + this( + errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + messageParameters = Map( + "relationName" -> + quoteNameParts(AttributeNameParser.parseAttributeName(table))), cause = None) } def this(table: Seq[String]) = { - this(errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + this( + errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", messageParameters = Map("relationName" -> quoteNameParts(table)), cause = None) } def this(tableIdent: Identifier) = { - this(errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + this( + errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", messageParameters = Map("relationName" -> quoted(tableIdent)), cause = None) } } -class TempTableAlreadyExistsException private( +class TempTableAlreadyExistsException private ( message: String, cause: Option[Throwable], errorClass: Option[String], messageParameters: Map[String, String]) - extends AnalysisException( - message, - cause = cause, - errorClass = errorClass, - messageParameters = messageParameters) { + extends AnalysisException( + message, + cause = cause, + errorClass = errorClass, + messageParameters = messageParameters) { def this( - errorClass: String, - messageParameters: Map[String, String], - cause: Option[Throwable] = None) = { + errorClass: String, + messageParameters: Map[String, String], + cause: Option[Throwable] = None) = { this( SparkThrowableHelper.getMessage(errorClass, messageParameters), cause, @@ -126,27 +136,31 @@ class TempTableAlreadyExistsException private( def this(table: String) = { this( errorClass = "TEMP_TABLE_OR_VIEW_ALREADY_EXISTS", - messageParameters = Map("relationName" - -> quoteNameParts(AttributeNameParser.parseAttributeName(table)))) + messageParameters = Map( + "relationName" + -> quoteNameParts(AttributeNameParser.parseAttributeName(table)))) } } // any changes to this class should be backward compatible as it may be used by external connectors class ViewAlreadyExistsException(errorClass: String, messageParameters: Map[String, String]) - extends AnalysisException(errorClass, messageParameters) { + extends AnalysisException(errorClass, messageParameters) { def this(ident: Identifier) = - this(errorClass = "VIEW_ALREADY_EXISTS", + this( + errorClass = "VIEW_ALREADY_EXISTS", messageParameters = Map("relationName" -> quoted(ident))) } // any changes to this class should be backward compatible as it may be used by external connectors class FunctionAlreadyExistsException(errorClass: String, messageParameters: Map[String, String]) - extends AnalysisException(errorClass, messageParameters) { + extends AnalysisException(errorClass, messageParameters) { def this(function: Seq[String]) = { - this (errorClass = "ROUTINE_ALREADY_EXISTS", - Map("routineName" -> quoteNameParts(function), + this( + errorClass = "ROUTINE_ALREADY_EXISTS", + Map( + "routineName" -> quoteNameParts(function), "newRoutineType" -> "routine", "existingRoutineType" -> "routine")) } @@ -157,16 +171,16 @@ class FunctionAlreadyExistsException(errorClass: String, messageParameters: Map[ } // any changes to this class should be backward compatible as it may be used by external connectors -class IndexAlreadyExistsException private( +class IndexAlreadyExistsException private ( message: String, cause: Option[Throwable], errorClass: Option[String], messageParameters: Map[String, String]) - extends AnalysisException( - 
message, - cause = cause, - errorClass = errorClass, - messageParameters = messageParameters) { + extends AnalysisException( + message, + cause = cause, + errorClass = errorClass, + messageParameters = messageParameters) { def this( errorClass: String, diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala index 8977d0be24d77..dbc7622c761e7 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala @@ -24,21 +24,24 @@ import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.util.ArrayImplicits._ /** - * Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception - * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. + * Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception as an + * [[org.apache.spark.sql.AnalysisException]] with the correct position information. */ -class NoSuchDatabaseException private[analysis]( +class NoSuchDatabaseException private[analysis] ( message: String, cause: Option[Throwable], errorClass: Option[String], messageParameters: Map[String, String]) - extends AnalysisException( - message, - cause = cause, - errorClass = errorClass, - messageParameters = messageParameters) { + extends AnalysisException( + message, + cause = cause, + errorClass = errorClass, + messageParameters = messageParameters) { - def this(errorClass: String, messageParameters: Map[String, String], cause: Option[Throwable]) = { + def this( + errorClass: String, + messageParameters: Map[String, String], + cause: Option[Throwable]) = { this( SparkThrowableHelper.getMessage(errorClass, messageParameters), cause = cause, @@ -55,16 +58,16 @@ class NoSuchDatabaseException private[analysis]( } // any changes to this class should be backward compatible as it may be used by external connectors -class NoSuchNamespaceException private( +class NoSuchNamespaceException private ( message: String, cause: Option[Throwable], errorClass: Option[String], messageParameters: Map[String, String]) - extends NoSuchDatabaseException( - message, - cause = cause, - errorClass = errorClass, - messageParameters = messageParameters) { + extends NoSuchDatabaseException( + message, + cause = cause, + errorClass = errorClass, + messageParameters = messageParameters) { def this(errorClass: String, messageParameters: Map[String, String]) = { this( @@ -75,29 +78,32 @@ class NoSuchNamespaceException private( } def this(namespace: Seq[String]) = { - this(errorClass = "SCHEMA_NOT_FOUND", - Map("schemaName" -> quoteNameParts(namespace))) + this(errorClass = "SCHEMA_NOT_FOUND", Map("schemaName" -> quoteNameParts(namespace))) } def this(namespace: Array[String]) = { - this(errorClass = "SCHEMA_NOT_FOUND", + this( + errorClass = "SCHEMA_NOT_FOUND", Map("schemaName" -> quoteNameParts(namespace.toImmutableArraySeq))) } } // any changes to this class should be backward compatible as it may be used by external connectors -class NoSuchTableException private( +class NoSuchTableException private ( message: String, cause: Option[Throwable], errorClass: Option[String], messageParameters: Map[String, String]) - extends AnalysisException( - message, - cause = cause, - errorClass = errorClass, - messageParameters = messageParameters) { + extends 
AnalysisException( + message, + cause = cause, + errorClass = errorClass, + messageParameters = messageParameters) { - def this(errorClass: String, messageParameters: Map[String, String], cause: Option[Throwable]) = { + def this( + errorClass: String, + messageParameters: Map[String, String], + cause: Option[Throwable]) = { this( SparkThrowableHelper.getMessage(errorClass, messageParameters), cause = cause, @@ -108,12 +114,13 @@ class NoSuchTableException private( def this(db: String, table: String) = { this( errorClass = "TABLE_OR_VIEW_NOT_FOUND", - messageParameters = Map("relationName" -> - (quoteIdentifier(db) + "." + quoteIdentifier(table))), + messageParameters = Map( + "relationName" -> + (quoteIdentifier(db) + "." + quoteIdentifier(table))), cause = None) } - def this(name : Seq[String]) = { + def this(name: Seq[String]) = { this( errorClass = "TABLE_OR_VIEW_NOT_FOUND", messageParameters = Map("relationName" -> quoteNameParts(name)), @@ -130,28 +137,28 @@ class NoSuchTableException private( // any changes to this class should be backward compatible as it may be used by external connectors class NoSuchViewException(errorClass: String, messageParameters: Map[String, String]) - extends AnalysisException(errorClass, messageParameters) { + extends AnalysisException(errorClass, messageParameters) { def this(ident: Identifier) = - this(errorClass = "VIEW_NOT_FOUND", - messageParameters = Map("relationName" -> quoted(ident))) + this(errorClass = "VIEW_NOT_FOUND", messageParameters = Map("relationName" -> quoted(ident))) } class NoSuchPermanentFunctionException(db: String, func: String) - extends AnalysisException(errorClass = "ROUTINE_NOT_FOUND", - Map("routineName" -> (quoteIdentifier(db) + "." + quoteIdentifier(func)))) + extends AnalysisException( + errorClass = "ROUTINE_NOT_FOUND", + Map("routineName" -> (quoteIdentifier(db) + "." + quoteIdentifier(func)))) // any changes to this class should be backward compatible as it may be used by external connectors -class NoSuchFunctionException private( +class NoSuchFunctionException private ( message: String, cause: Option[Throwable], errorClass: Option[String], messageParameters: Map[String, String]) - extends AnalysisException( - message, - cause = cause, - errorClass = errorClass, - messageParameters = messageParameters) { + extends AnalysisException( + message, + cause = cause, + errorClass = errorClass, + messageParameters = messageParameters) { def this(errorClass: String, messageParameters: Map[String, String]) = { this( @@ -162,7 +169,8 @@ class NoSuchFunctionException private( } def this(db: String, func: String) = { - this(errorClass = "ROUTINE_NOT_FOUND", + this( + errorClass = "ROUTINE_NOT_FOUND", Map("routineName" -> (quoteIdentifier(db) + "." 
+ quoteIdentifier(func)))) } @@ -172,19 +180,19 @@ class NoSuchFunctionException private( } class NoSuchTempFunctionException(func: String) - extends AnalysisException(errorClass = "ROUTINE_NOT_FOUND", Map("routineName" -> s"`$func`")) + extends AnalysisException(errorClass = "ROUTINE_NOT_FOUND", Map("routineName" -> s"`$func`")) // any changes to this class should be backward compatible as it may be used by external connectors -class NoSuchIndexException private( +class NoSuchIndexException private ( message: String, cause: Option[Throwable], errorClass: Option[String], messageParameters: Map[String, String]) - extends AnalysisException( - message, - cause = cause, - errorClass = errorClass, - messageParameters = messageParameters) { + extends AnalysisException( + message, + cause = cause, + errorClass = errorClass, + messageParameters = messageParameters) { def this( errorClass: String, @@ -201,3 +209,14 @@ class NoSuchIndexException private( this("INDEX_NOT_FOUND", Map("indexName" -> indexName, "tableName" -> tableName), cause) } } + +class CannotReplaceMissingTableException( + tableIdentifier: Identifier, + cause: Option[Throwable] = None) + extends AnalysisException( + errorClass = "TABLE_OR_VIEW_NOT_FOUND", + messageParameters = Map( + "relationName" + -> quoteNameParts( + (tableIdentifier.namespace :+ tableIdentifier.name).toImmutableArraySeq)), + cause = cause) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala index 639b23f714149..10f734b3f84ed 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala @@ -19,24 +19,23 @@ package org.apache.spark.sql.catalyst.encoders import java.{sql => jsql} import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInt} import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} -import java.util.concurrent.ConcurrentHashMap import scala.reflect.{classTag, ClassTag} import org.apache.spark.sql.{Encoder, Row} +import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} import org.apache.spark.util.SparkClassUtils /** - * A non implementation specific encoder. This encoder containers all the information needed - * to generate an implementation specific encoder (e.g. InternalRow <=> Custom Object). + * A non implementation specific encoder. This encoder containers all the information needed to + * generate an implementation specific encoder (e.g. InternalRow <=> Custom Object). * * The input of the serialization does not need to match the external type of the encoder. This is - * called lenient serialization. An example of this is lenient date serialization, in this case both - * [[java.sql.Date]] and [[java.time.LocalDate]] are allowed. Deserialization is never lenient; it - * will always produce instance of the external type. - * + * called lenient serialization. An example of this is lenient date serialization, in this case + * both [[java.sql.Date]] and [[java.time.LocalDate]] are allowed. Deserialization is never + * lenient; it will always produce instance of the external type. 
*/ trait AgnosticEncoder[T] extends Encoder[T] { def isPrimitive: Boolean @@ -47,16 +46,29 @@ trait AgnosticEncoder[T] extends Encoder[T] { def isStruct: Boolean = false } +/** + * Extract an [[AgnosticEncoder]] from an [[Encoder]]. + */ +trait ToAgnosticEncoder[T] { + def encoder: AgnosticEncoder[T] +} + object AgnosticEncoders { + def agnosticEncoderFor[T: Encoder]: AgnosticEncoder[T] = implicitly[Encoder[T]] match { + case a: AgnosticEncoder[T] => a + case e: ToAgnosticEncoder[T @unchecked] => e.encoder + case other => throw ExecutionErrors.invalidAgnosticEncoderError(other) + } + case class OptionEncoder[E](elementEncoder: AgnosticEncoder[E]) - extends AgnosticEncoder[Option[E]] { + extends AgnosticEncoder[Option[E]] { override def isPrimitive: Boolean = false override def dataType: DataType = elementEncoder.dataType override val clsTag: ClassTag[Option[E]] = ClassTag(classOf[Option[E]]) } case class ArrayEncoder[E](element: AgnosticEncoder[E], containsNull: Boolean) - extends AgnosticEncoder[Array[E]] { + extends AgnosticEncoder[Array[E]] { override def isPrimitive: Boolean = false override def dataType: DataType = ArrayType(element.dataType, containsNull) override val clsTag: ClassTag[Array[E]] = element.clsTag.wrap @@ -73,7 +85,7 @@ object AgnosticEncoders { element: AgnosticEncoder[E], containsNull: Boolean, override val lenientSerialization: Boolean) - extends AgnosticEncoder[C] { + extends AgnosticEncoder[C] { override def isPrimitive: Boolean = false override val dataType: DataType = ArrayType(element.dataType, containsNull) } @@ -83,12 +95,10 @@ object AgnosticEncoders { keyEncoder: AgnosticEncoder[K], valueEncoder: AgnosticEncoder[V], valueContainsNull: Boolean) - extends AgnosticEncoder[C] { + extends AgnosticEncoder[C] { override def isPrimitive: Boolean = false - override val dataType: DataType = MapType( - keyEncoder.dataType, - valueEncoder.dataType, - valueContainsNull) + override val dataType: DataType = + MapType(keyEncoder.dataType, valueEncoder.dataType, valueContainsNull) } case class EncoderField( @@ -114,17 +124,28 @@ object AgnosticEncoders { case class ProductEncoder[K]( override val clsTag: ClassTag[K], override val fields: Seq[EncoderField], - outerPointerGetter: Option[() => AnyRef]) extends StructEncoder[K] + outerPointerGetter: Option[() => AnyRef]) + extends StructEncoder[K] object ProductEncoder { - val cachedCls = new ConcurrentHashMap[Int, Class[_]] - private[sql] def tuple(encoders: Seq[AgnosticEncoder[_]]): AgnosticEncoder[_] = { - val fields = encoders.zipWithIndex.map { - case (e, id) => EncoderField(s"_${id + 1}", e, e.nullable, Metadata.empty) + private val MAX_TUPLE_ELEMENTS = 22 + + private val tupleClassTags = Array.tabulate[ClassTag[Any]](MAX_TUPLE_ELEMENTS + 1) { + case 0 => null + case i => ClassTag(SparkClassUtils.classForName(s"scala.Tuple$i")) + } + + private[sql] def tuple( + encoders: Seq[AgnosticEncoder[_]], + elementsCanBeNull: Boolean = false): AgnosticEncoder[_] = { + val numElements = encoders.size + if (numElements < 1 || numElements > MAX_TUPLE_ELEMENTS) { + throw ExecutionErrors.elementsOfTupleExceedLimitError() } - val cls = cachedCls.computeIfAbsent(encoders.size, - _ => SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")) - ProductEncoder[Any](ClassTag(cls), fields, None) + val fields = encoders.zipWithIndex.map { case (e, id) => + EncoderField(s"_${id + 1}", e, e.nullable || elementsCanBeNull, Metadata.empty) + } + ProductEncoder[Any](tupleClassTags(numElements), fields, None) } 
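`ProductEncoder.tuple` above now reuses pre-computed `Tuple1`..`Tuple22` class tags and rejects other arities. From the public API this path is reached through `Encoders.tuple`; a small sketch:

```scala
import org.apache.spark.sql.Encoders

// Composes per-field encoders into a tuple encoder; Scala tuples cap out at 22 fields,
// which is the arity limit enforced above.
val pairEncoder = Encoders.tuple(Encoders.STRING, Encoders.scalaInt)
println(pairEncoder.schema) // two fields named _1 and _2
```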
private[sql] def isTuple(tag: ClassTag[_]): Boolean = { @@ -141,19 +162,19 @@ object AgnosticEncoders { object UnboundRowEncoder extends BaseRowEncoder { override val schema: StructType = new StructType() override val fields: Seq[EncoderField] = Seq.empty -} + } case class JavaBeanEncoder[K]( override val clsTag: ClassTag[K], override val fields: Seq[EncoderField]) - extends StructEncoder[K] + extends StructEncoder[K] // This will only work for encoding from/to Sparks' InternalRow format. // It is here for compatibility. case class UDTEncoder[E >: Null]( udt: UserDefinedType[E], udtClass: Class[_ <: UserDefinedType[_]]) - extends AgnosticEncoder[E] { + extends AgnosticEncoder[E] { override def isPrimitive: Boolean = false override def dataType: DataType = udt override def clsTag: ClassTag[E] = ClassTag(udt.userClass) @@ -164,21 +185,19 @@ object AgnosticEncoders { override def isPrimitive: Boolean = false override def dataType: DataType = StringType } - case class ScalaEnumEncoder[T, E]( - parent: Class[T], - override val clsTag: ClassTag[E]) - extends EnumEncoder[E] + case class ScalaEnumEncoder[T, E](parent: Class[T], override val clsTag: ClassTag[E]) + extends EnumEncoder[E] case class JavaEnumEncoder[E](override val clsTag: ClassTag[E]) extends EnumEncoder[E] - protected abstract class LeafEncoder[E : ClassTag](override val dataType: DataType) - extends AgnosticEncoder[E] { + protected abstract class LeafEncoder[E: ClassTag](override val dataType: DataType) + extends AgnosticEncoder[E] { override val clsTag: ClassTag[E] = classTag[E] override val isPrimitive: Boolean = clsTag.runtimeClass.isPrimitive } // Primitive encoders - abstract class PrimitiveLeafEncoder[E : ClassTag](dataType: DataType) - extends LeafEncoder[E](dataType) + abstract class PrimitiveLeafEncoder[E: ClassTag](dataType: DataType) + extends LeafEncoder[E](dataType) case object PrimitiveBooleanEncoder extends PrimitiveLeafEncoder[Boolean](BooleanType) case object PrimitiveByteEncoder extends PrimitiveLeafEncoder[Byte](ByteType) case object PrimitiveShortEncoder extends PrimitiveLeafEncoder[Short](ShortType) @@ -188,24 +207,24 @@ object AgnosticEncoders { case object PrimitiveDoubleEncoder extends PrimitiveLeafEncoder[Double](DoubleType) // Primitive wrapper encoders. 
- abstract class BoxedLeafEncoder[E : ClassTag, P]( + abstract class BoxedLeafEncoder[E: ClassTag, P]( dataType: DataType, val primitive: PrimitiveLeafEncoder[P]) - extends LeafEncoder[E](dataType) + extends LeafEncoder[E](dataType) case object BoxedBooleanEncoder - extends BoxedLeafEncoder[java.lang.Boolean, Boolean](BooleanType, PrimitiveBooleanEncoder) + extends BoxedLeafEncoder[java.lang.Boolean, Boolean](BooleanType, PrimitiveBooleanEncoder) case object BoxedByteEncoder - extends BoxedLeafEncoder[java.lang.Byte, Byte](ByteType, PrimitiveByteEncoder) + extends BoxedLeafEncoder[java.lang.Byte, Byte](ByteType, PrimitiveByteEncoder) case object BoxedShortEncoder - extends BoxedLeafEncoder[java.lang.Short, Short](ShortType, PrimitiveShortEncoder) + extends BoxedLeafEncoder[java.lang.Short, Short](ShortType, PrimitiveShortEncoder) case object BoxedIntEncoder - extends BoxedLeafEncoder[java.lang.Integer, Int](IntegerType, PrimitiveIntEncoder) + extends BoxedLeafEncoder[java.lang.Integer, Int](IntegerType, PrimitiveIntEncoder) case object BoxedLongEncoder - extends BoxedLeafEncoder[java.lang.Long, Long](LongType, PrimitiveLongEncoder) + extends BoxedLeafEncoder[java.lang.Long, Long](LongType, PrimitiveLongEncoder) case object BoxedFloatEncoder - extends BoxedLeafEncoder[java.lang.Float, Float](FloatType, PrimitiveFloatEncoder) + extends BoxedLeafEncoder[java.lang.Float, Float](FloatType, PrimitiveFloatEncoder) case object BoxedDoubleEncoder - extends BoxedLeafEncoder[java.lang.Double, Double](DoubleType, PrimitiveDoubleEncoder) + extends BoxedLeafEncoder[java.lang.Double, Double](DoubleType, PrimitiveDoubleEncoder) // Nullable leaf encoders case object NullEncoder extends LeafEncoder[java.lang.Void](NullType) @@ -218,19 +237,19 @@ object AgnosticEncoders { case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType()) case object VariantEncoder extends LeafEncoder[VariantVal](VariantType) case class DateEncoder(override val lenientSerialization: Boolean) - extends LeafEncoder[jsql.Date](DateType) + extends LeafEncoder[jsql.Date](DateType) case class LocalDateEncoder(override val lenientSerialization: Boolean) - extends LeafEncoder[LocalDate](DateType) + extends LeafEncoder[LocalDate](DateType) case class TimestampEncoder(override val lenientSerialization: Boolean) - extends LeafEncoder[jsql.Timestamp](TimestampType) + extends LeafEncoder[jsql.Timestamp](TimestampType) case class InstantEncoder(override val lenientSerialization: Boolean) - extends LeafEncoder[Instant](TimestampType) + extends LeafEncoder[Instant](TimestampType) case object LocalDateTimeEncoder extends LeafEncoder[LocalDateTime](TimestampNTZType) case class SparkDecimalEncoder(dt: DecimalType) extends LeafEncoder[Decimal](dt) case class ScalaDecimalEncoder(dt: DecimalType) extends LeafEncoder[BigDecimal](dt) case class JavaDecimalEncoder(dt: DecimalType, override val lenientSerialization: Boolean) - extends LeafEncoder[JBigDecimal](dt) + extends LeafEncoder[JBigDecimal](dt) val STRICT_DATE_ENCODER: DateEncoder = DateEncoder(lenientSerialization = false) val STRICT_LOCAL_DATE_ENCODER: LocalDateEncoder = LocalDateEncoder(lenientSerialization = false) @@ -257,7 +276,8 @@ object AgnosticEncoders { case class TransformingEncoder[I, O]( clsTag: ClassTag[I], transformed: AgnosticEncoder[O], - codecProvider: () => Codec[_ >: I, O]) extends AgnosticEncoder[I] { + codecProvider: () => Codec[_ >: I, O]) + extends AgnosticEncoder[I] { override def isPrimitive: Boolean = transformed.isPrimitive override def 
dataType: DataType = transformed.dataType override def schema: StructType = transformed.schema diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala index 8587688956950..909556492847f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala @@ -48,10 +48,9 @@ object OuterScopes { /** * Adds a new outer scope to this context that can be used when instantiating an `inner class` - * during deserialization. Inner classes are created when a case class is defined in the - * Spark REPL and registering the outer scope that this class was defined in allows us to create - * new instances on the spark executors. In normal use, users should not need to call this - * function. + * during deserialization. Inner classes are created when a case class is defined in the Spark + * REPL and registering the outer scope that this class was defined in allows us to create new + * instances on the spark executors. In normal use, users should not need to call this function. * * Warning: this function operates on the assumption that there is only ever one instance of any * given wrapper class. @@ -65,7 +64,7 @@ object OuterScopes { } /** - * Returns a function which can get the outer scope for the given inner class. By using function + * Returns a function which can get the outer scope for the given inner class. By using function * as return type, we can delay the process of getting outer pointer to execution time, which is * useful for inner class defined in REPL. */ @@ -134,8 +133,8 @@ object OuterScopes { } case _ => null } - } else { - () => outer + } else { () => + outer } } @@ -162,7 +161,7 @@ object OuterScopes { * dead entries after GC (using a [[ReferenceQueue]]). */ private[catalyst] class HashableWeakReference(v: AnyRef, queue: ReferenceQueue[AnyRef]) - extends WeakReference[AnyRef](v, queue) { + extends WeakReference[AnyRef](v, queue) { def this(v: AnyRef) = this(v, null) private[this] val hash = v.hashCode() override def hashCode(): Int = hash diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index c507e952630f6..8b6da805a6e87 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ /** - * A factory for constructing encoders that convert external row to/from the Spark SQL - * internal binary representation. + * A factory for constructing encoders that convert external row to/from the Spark SQL internal + * binary representation. 
* * The following is a mapping between Spark SQL types and its allowed external types: * {{{ @@ -68,67 +68,65 @@ object RowEncoder extends DataTypeErrorsBase { encoderForDataType(schema, lenient).asInstanceOf[AgnosticEncoder[Row]] } - private[sql] def encoderForDataType( - dataType: DataType, - lenient: Boolean): AgnosticEncoder[_] = dataType match { - case NullType => NullEncoder - case BooleanType => BoxedBooleanEncoder - case ByteType => BoxedByteEncoder - case ShortType => BoxedShortEncoder - case IntegerType => BoxedIntEncoder - case LongType => BoxedLongEncoder - case FloatType => BoxedFloatEncoder - case DoubleType => BoxedDoubleEncoder - case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization = true) - case BinaryType => BinaryEncoder - case _: StringType => StringEncoder - case TimestampType if SqlApiConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient) - case TimestampType => TimestampEncoder(lenient) - case TimestampNTZType => LocalDateTimeEncoder - case DateType if SqlApiConf.get.datetimeJava8ApiEnabled => LocalDateEncoder(lenient) - case DateType => DateEncoder(lenient) - case CalendarIntervalType => CalendarIntervalEncoder - case _: DayTimeIntervalType => DayTimeIntervalEncoder - case _: YearMonthIntervalType => YearMonthIntervalEncoder - case _: VariantType => VariantEncoder - case p: PythonUserDefinedType => - // TODO check if this works. - encoderForDataType(p.sqlType, lenient) - case udt: UserDefinedType[_] => - val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) - val udtClass: Class[_] = if (annotation != null) { - annotation.udt() - } else { - UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse { - throw ExecutionErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt) + private[sql] def encoderForDataType(dataType: DataType, lenient: Boolean): AgnosticEncoder[_] = + dataType match { + case NullType => NullEncoder + case BooleanType => BoxedBooleanEncoder + case ByteType => BoxedByteEncoder + case ShortType => BoxedShortEncoder + case IntegerType => BoxedIntEncoder + case LongType => BoxedLongEncoder + case FloatType => BoxedFloatEncoder + case DoubleType => BoxedDoubleEncoder + case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization = true) + case BinaryType => BinaryEncoder + case _: StringType => StringEncoder + case TimestampType if SqlApiConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient) + case TimestampType => TimestampEncoder(lenient) + case TimestampNTZType => LocalDateTimeEncoder + case DateType if SqlApiConf.get.datetimeJava8ApiEnabled => LocalDateEncoder(lenient) + case DateType => DateEncoder(lenient) + case CalendarIntervalType => CalendarIntervalEncoder + case _: DayTimeIntervalType => DayTimeIntervalEncoder + case _: YearMonthIntervalType => YearMonthIntervalEncoder + case _: VariantType => VariantEncoder + case p: PythonUserDefinedType => + // TODO check if this works. 
+ encoderForDataType(p.sqlType, lenient) + case udt: UserDefinedType[_] => + val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) + val udtClass: Class[_] = if (annotation != null) { + annotation.udt() + } else { + UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse { + throw ExecutionErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt) + } } - } - UDTEncoder(udt, udtClass.asInstanceOf[Class[_ <: UserDefinedType[_]]]) - case ArrayType(elementType, containsNull) => - IterableEncoder( - classTag[mutable.ArraySeq[_]], - encoderForDataType(elementType, lenient), - containsNull, - lenientSerialization = true) - case MapType(keyType, valueType, valueContainsNull) => - MapEncoder( - classTag[scala.collection.Map[_, _]], - encoderForDataType(keyType, lenient), - encoderForDataType(valueType, lenient), - valueContainsNull) - case StructType(fields) => - AgnosticRowEncoder(fields.map { field => - EncoderField( - field.name, - encoderForDataType(field.dataType, lenient), - field.nullable, - field.metadata) - }.toImmutableArraySeq) + UDTEncoder(udt, udtClass.asInstanceOf[Class[_ <: UserDefinedType[_]]]) + case ArrayType(elementType, containsNull) => + IterableEncoder( + classTag[mutable.ArraySeq[_]], + encoderForDataType(elementType, lenient), + containsNull, + lenientSerialization = true) + case MapType(keyType, valueType, valueContainsNull) => + MapEncoder( + classTag[scala.collection.Map[_, _]], + encoderForDataType(keyType, lenient), + encoderForDataType(valueType, lenient), + valueContainsNull) + case StructType(fields) => + AgnosticRowEncoder(fields.map { field => + EncoderField( + field.name, + encoderForDataType(field.dataType, lenient), + field.nullable, + field.metadata) + }.toImmutableArraySeq) - case _ => - throw new AnalysisException( - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_ENCODER", - messageParameters = Map("dataType" -> toSQLType(dataType)) - ) - } + case _ => + throw new AnalysisException( + errorClass = "UNSUPPORTED_DATA_TYPE_FOR_ENCODER", + messageParameters = Map("dataType" -> toSQLType(dataType))) + } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala index 46862ebbccdfd..0f21972552339 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala @@ -16,15 +16,20 @@ */ package org.apache.spark.sql.catalyst.encoders -import org.apache.spark.util.SparkSerDeUtils +import java.lang.invoke.{MethodHandle, MethodHandles, MethodType} + +import org.apache.spark.sql.errors.ExecutionErrors +import org.apache.spark.util.{SparkClassUtils, SparkSerDeUtils} /** * Codec for doing conversions between two representations. * - * @tparam I input type (typically the external representation of the data. - * @tparam O output type (typically the internal representation of the data. + * @tparam I + * input type (typically the external representation of the data. + * @tparam O + * output type (typically the internal representation of the data. 
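A short sketch of the DataType-to-encoder mapping implemented by encoderForDataType, assuming RowEncoder.encoderFor is reachable from the calling code:

    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    val schema = StructType(Seq(
      StructField("id", LongType, nullable = false),
      StructField("tags", ArrayType(StringType))))

    // LongType resolves to BoxedLongEncoder, the array to an IterableEncoder over
    // StringEncoder, and the struct itself to an AgnosticRowEncoder built from one
    // EncoderField per StructField.
    val encoder = RowEncoder.encoderFor(schema, lenient = false)
    println(encoder.schema.treeString)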
*/ -trait Codec[I, O] { +trait Codec[I, O] extends Serializable { def encode(in: I): O def decode(out: O): I } @@ -40,3 +45,29 @@ class JavaSerializationCodec[I] extends Codec[I, Array[Byte]] { object JavaSerializationCodec extends (() => Codec[Any, Array[Byte]]) { override def apply(): Codec[Any, Array[Byte]] = new JavaSerializationCodec[Any] } + +/** + * A codec that uses Kryo to (de)serialize arbitrary objects to and from a byte array. + * + * Please note that this is currently only supported for Classic Spark applications. The reason + * for this is that Connect applications can have a significantly different classpath than the + * driver or executor. This makes having a the same Kryo configuration on both the client and + * server (driver & executors) very tricky. As a workaround a user can define their own Codec + * which internalizes the Kryo configuration. + */ +object KryoSerializationCodec extends (() => Codec[Any, Array[Byte]]) { + private lazy val kryoCodecConstructor: MethodHandle = { + val cls = SparkClassUtils.classForName( + "org.apache.spark.sql.catalyst.encoders.KryoSerializationCodecImpl") + MethodHandles.lookup().findConstructor(cls, MethodType.methodType(classOf[Unit])) + } + + override def apply(): Codec[Any, Array[Byte]] = { + try { + kryoCodecConstructor.invoke().asInstanceOf[Codec[Any, Array[Byte]]] + } catch { + case _: ClassNotFoundException => + throw ExecutionErrors.cannotUseKryoSerialization() + } + } +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/OrderUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/OrderUtils.scala index 385e0f00695a3..76442accbd35b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/OrderUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/OrderUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.types.{ArrayType, AtomicType, DataType, NullType, StructType, UserDefinedType, VariantType} object OrderUtils { + /** * Returns true iff the data type can be ordered (i.e. can be sorted). */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 7f21ab25ad4e5..6977d9f3185ac 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -21,11 +21,12 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.util.ArrayImplicits._ /** - * A row implementation that uses an array of objects as the underlying storage. Note that, while + * A row implementation that uses an array of objects as the underlying storage. Note that, while * the array is not copied, and thus could technically be mutated after creation, this is not * allowed. */ class GenericRow(protected[sql] val values: Array[Any]) extends Row { + /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -41,7 +42,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } class GenericRowWithSchema(values: Array[Any], override val schema: StructType) - extends GenericRow(values) { + extends GenericRow(values) { /** No-arg constructor for serialization. 
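Because KryoSerializationCodec is reserved for Classic Spark, the workaround its scaladoc suggests is a user-defined Codec paired with a TransformingEncoder. A hedged sketch using a hypothetical UserId type that is not part of Spark:

    import scala.reflect.classTag
    import org.apache.spark.sql.catalyst.encoders.Codec
    import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{StringEncoder, TransformingEncoder}

    final case class UserId(value: Long)    // hypothetical domain type

    // Codec now extends Serializable, so the codec and its provider can be shipped
    // together with the encoder.
    class UserIdCodec extends Codec[UserId, String] {
      override def encode(in: UserId): String = in.value.toString
      override def decode(out: String): UserId = UserId(out.toLong)
    }

    val userIdEncoder =
      TransformingEncoder[UserId, String](classTag[UserId], StringEncoder, () => new UserIdCodec)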
*/ protected def this() = this(null, null) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index 38ecd29266db7..46fb4a3397c59 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -120,7 +120,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { override def visitDayTimeIntervalDataType(ctx: DayTimeIntervalDataTypeContext): DataType = { val startStr = ctx.from.getText.toLowerCase(Locale.ROOT) val start = DayTimeIntervalType.stringToField(startStr) - if (ctx.to != null ) { + if (ctx.to != null) { val endStr = ctx.to.getText.toLowerCase(Locale.ROOT) val end = DayTimeIntervalType.stringToField(endStr) if (end <= start) { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala index ab665f360b0a6..1a1a9b01de3b1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala @@ -22,9 +22,10 @@ import org.apache.spark.sql.types.{DataType, StructType} * Interface for [[DataType]] parsing functionality. */ trait DataTypeParserInterface { + /** - * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list - * of field definitions which will preserve the correct Hive metadata. + * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list of + * field definitions which will preserve the correct Hive metadata. */ @throws[ParseException]("Text cannot be parsed to a schema") def parseTableSchema(sqlText: String): StructType diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala index 8ac5939bca944..7d1986e727f79 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala @@ -23,35 +23,36 @@ import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.types._ /** - * Parser that turns case class strings into datatypes. This is only here to maintain compatibility - * with Parquet files written by Spark 1.1 and below. + * Parser that turns case class strings into datatypes. This is only here to maintain + * compatibility with Parquet files written by Spark 1.1 and below. 
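The comma-separated field list accepted by parseTableSchema is essentially the DDL format exposed through the public fromDDL helpers, which is the easiest way to exercise this parsing path from user code:

    import org.apache.spark.sql.types.{DataType, StructType}

    val tableSchema = StructType.fromDDL("id BIGINT, event STRUCT<ts: TIMESTAMP, name: STRING>")
    val mapType: DataType = DataType.fromDDL("MAP<STRING, ARRAY<INT>>")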
*/ object LegacyTypeStringParser extends RegexParsers { protected lazy val primitiveType: Parser[DataType] = - ( "StringType" ^^^ StringType - | "FloatType" ^^^ FloatType - | "IntegerType" ^^^ IntegerType - | "ByteType" ^^^ ByteType - | "ShortType" ^^^ ShortType - | "DoubleType" ^^^ DoubleType - | "LongType" ^^^ LongType - | "BinaryType" ^^^ BinaryType - | "BooleanType" ^^^ BooleanType - | "DateType" ^^^ DateType - | "DecimalType()" ^^^ DecimalType.USER_DEFAULT - | fixedDecimalType - | "TimestampType" ^^^ TimestampType - ) + ( + "StringType" ^^^ StringType + | "FloatType" ^^^ FloatType + | "IntegerType" ^^^ IntegerType + | "ByteType" ^^^ ByteType + | "ShortType" ^^^ ShortType + | "DoubleType" ^^^ DoubleType + | "LongType" ^^^ LongType + | "BinaryType" ^^^ BinaryType + | "BooleanType" ^^^ BooleanType + | "DateType" ^^^ DateType + | "DecimalType()" ^^^ DecimalType.USER_DEFAULT + | fixedDecimalType + | "TimestampType" ^^^ TimestampType + ) protected lazy val fixedDecimalType: Parser[DataType] = - ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { - case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { case precision ~ scale => + DecimalType(precision.toInt, scale.toInt) } protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { - case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { case tpe ~ _ ~ containsNull => + ArrayType(tpe, containsNull) } protected lazy val mapType: Parser[DataType] = @@ -66,21 +67,23 @@ object LegacyTypeStringParser extends RegexParsers { } protected lazy val boolVal: Parser[Boolean] = - ( "true" ^^^ true - | "false" ^^^ false - ) + ( + "true" ^^^ true + | "false" ^^^ false + ) protected lazy val structType: Parser[DataType] = - "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { - case fields => StructType(fields) + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { case fields => + StructType(fields) } protected lazy val dataType: Parser[DataType] = - ( arrayType - | mapType - | structType - | primitiveType - ) + ( + arrayType + | mapType + | structType + | primitiveType + ) /** * Parses a string representation of a DataType. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala index 99e63d783838f..461d79ec22cf0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala @@ -32,7 +32,7 @@ class SparkRecognitionException( ctx: ParserRuleContext, val errorClass: Option[String] = None, val messageParameters: Map[String, String] = Map.empty) - extends RecognitionException(message, recognizer, input, ctx) { + extends RecognitionException(message, recognizer, input, ctx) { /** Construct from a given [[RecognitionException]], with additional error information. */ def this( @@ -50,7 +50,7 @@ class SparkRecognitionException( Some(errorClass), messageParameters) - /** Construct with pure errorClass and messageParameter information. */ + /** Construct with pure errorClass and messageParameter information. 
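For reference, a legacy schema string of the kind written by Spark 1.1-era Parquet files and how it parses; both the exact string shape and the parseString entry point (the method under the "Parses a string representation of a DataType" scaladoc) are assumptions here:

    import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser
    import org.apache.spark.sql.types.DataType

    // Assumed shape of the legacy case-class dump.
    val legacy =
      "StructType(List(StructField(id,LongType,false),StructField(name,StringType,true)))"
    val parsed: DataType = LegacyTypeStringParser.parseString(legacy)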
*/ def this(errorClass: String, messageParameters: Map[String, String]) = this("", null, null, null, Some(errorClass), messageParameters) } @@ -59,12 +59,12 @@ class SparkRecognitionException( * A [[SparkParserErrorStrategy]] extends the [[DefaultErrorStrategy]], that does special handling * on errors. * - * The intention of this class is to provide more information of these errors encountered in - * ANTLR parser to the downstream consumers, to be able to apply the [[SparkThrowable]] error - * message framework to these exceptions. + * The intention of this class is to provide more information of these errors encountered in ANTLR + * parser to the downstream consumers, to be able to apply the [[SparkThrowable]] error message + * framework to these exceptions. */ class SparkParserErrorStrategy() extends DefaultErrorStrategy { - private val userWordDict : Map[String, String] = Map("''" -> "end of input") + private val userWordDict: Map[String, String] = Map("''" -> "end of input") /** Get the user-facing display of the error token. */ override def getTokenErrorDisplay(t: Token): String = { @@ -76,9 +76,7 @@ class SparkParserErrorStrategy() extends DefaultErrorStrategy { val exceptionWithErrorClass = new SparkRecognitionException( e, "PARSE_SYNTAX_ERROR", - messageParameters = Map( - "error" -> getTokenErrorDisplay(e.getOffendingToken), - "hint" -> "")) + messageParameters = Map("error" -> getTokenErrorDisplay(e.getOffendingToken), "hint" -> "")) recognizer.notifyErrorListeners(e.getOffendingToken, "", exceptionWithErrorClass) } @@ -116,18 +114,17 @@ class SparkParserErrorStrategy() extends DefaultErrorStrategy { /** * Inspired by [[org.antlr.v4.runtime.BailErrorStrategy]], which is used in two-stage parsing: - * This error strategy allows the first stage of two-stage parsing to immediately terminate - * if an error is encountered, and immediately fall back to the second stage. In addition to - * avoiding wasted work by attempting to recover from errors here, the empty implementation - * of sync improves the performance of the first stage. + * This error strategy allows the first stage of two-stage parsing to immediately terminate if an + * error is encountered, and immediately fall back to the second stage. In addition to avoiding + * wasted work by attempting to recover from errors here, the empty implementation of sync + * improves the performance of the first stage. */ class SparkParserBailErrorStrategy() extends SparkParserErrorStrategy { /** - * Instead of recovering from exception e, re-throw it wrapped - * in a [[ParseCancellationException]] so it is not caught by the - * rule function catches. Use [[Exception#getCause]] to get the - * original [[RecognitionException]]. + * Instead of recovering from exception e, re-throw it wrapped in a + * [[ParseCancellationException]] so it is not caught by the rule function catches. Use + * [[Exception#getCause]] to get the original [[RecognitionException]]. */ override def recover(recognizer: Parser, e: RecognitionException): Unit = { var context = recognizer.getContext @@ -139,8 +136,8 @@ class SparkParserBailErrorStrategy() extends SparkParserErrorStrategy { } /** - * Make sure we don't attempt to recover inline; if the parser - * successfully recovers, it won't throw an exception. + * Make sure we don't attempt to recover inline; if the parser successfully recovers, it won't + * throw an exception. 
*/ @throws[RecognitionException] override def recoverInline(recognizer: Parser): Token = { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala index 0b9e6ea021be1..10da24567545b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.types.{DataType, StructType} * Base SQL parsing infrastructure. */ abstract class AbstractParser extends DataTypeParserInterface with Logging { + /** Creates/Resolves DataType for a given SQL string. */ override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => astBuilder.visitSingleDataType(parser.singleDataType()) @@ -78,8 +79,7 @@ abstract class AbstractParser extends DataTypeParserInterface with Logging { parser.setErrorHandler(new SparkParserBailErrorStrategy()) parser.getInterpreter.setPredictionMode(PredictionMode.SLL) toResult(parser) - } - catch { + } catch { case e: ParseCancellationException => // if we fail, parse with LL mode w/ SparkParserErrorStrategy tokenStream.seek(0) // rewind input stream @@ -90,8 +90,7 @@ abstract class AbstractParser extends DataTypeParserInterface with Logging { parser.getInterpreter.setPredictionMode(PredictionMode.LL) toResult(parser) } - } - catch { + } catch { case e: ParseException if e.command.isDefined => throw e case e: ParseException => @@ -187,7 +186,7 @@ case object ParseErrorListener extends BaseErrorListener { * A [[ParseException]] is an [[SparkException]] that is thrown during the parse process. It * contains fields and an extended error message that make reporting and diagnosing errors easier. 
*/ -class ParseException private( +class ParseException private ( val command: Option[String], message: String, val start: Origin, @@ -195,17 +194,18 @@ class ParseException private( errorClass: Option[String] = None, messageParameters: Map[String, String] = Map.empty, queryContext: Array[QueryContext] = ParseException.getQueryContext()) - extends AnalysisException( - message, - start.line, - start.startPosition, - None, - errorClass, - messageParameters, - queryContext) { + extends AnalysisException( + message, + start.line, + start.startPosition, + None, + errorClass, + messageParameters, + queryContext) { def this(errorClass: String, messageParameters: Map[String, String], ctx: ParserRuleContext) = - this(Option(SparkParserUtils.command(ctx)), + this( + Option(SparkParserUtils.command(ctx)), SparkThrowableHelper.getMessage(errorClass, messageParameters), SparkParserUtils.position(ctx.getStart), SparkParserUtils.position(ctx.getStop), @@ -310,14 +310,16 @@ case object PostProcessor extends SqlBaseParserBaseListener { throw QueryParsingErrors.invalidIdentifierError(ident, ctx) } - /** Throws error message when unquoted identifier contains characters outside a-z, A-Z, 0-9, _ */ + /** + * Throws error message when unquoted identifier contains characters outside a-z, A-Z, 0-9, _ + */ override def exitUnquotedIdentifier(ctx: SqlBaseParser.UnquotedIdentifierContext): Unit = { val ident = ctx.getText if (ident.exists(c => - !(c >= 'a' && c <= 'z') && - !(c >= 'A' && c <= 'Z') && - !(c >= '0' && c <= '9') && - c != '_')) { + !(c >= 'a' && c <= 'z') && + !(c >= 'A' && c <= 'Z') && + !(c >= '0' && c <= '9') && + c != '_')) { throw QueryParsingErrors.invalidIdentifierError(ident, ctx) } } @@ -353,9 +355,7 @@ case object PostProcessor extends SqlBaseParserBaseListener { replaceTokenByIdentifier(ctx, 0)(identity) } - private def replaceTokenByIdentifier( - ctx: ParserRuleContext, - stripMargins: Int)( + private def replaceTokenByIdentifier(ctx: ParserRuleContext, stripMargins: Int)( f: CommonToken => CommonToken = identity): Unit = { val parent = ctx.getParent parent.removeLastChild() @@ -373,8 +373,8 @@ case object PostProcessor extends SqlBaseParserBaseListener { /** * The post-processor checks the unclosed bracketed comment. 
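At the user level these pieces surface as a ParseException carrying the PARSE_SYNTAX_ERROR error class, with the offending token rendered by getTokenErrorDisplay. A small sketch, assuming an active SparkSession named spark:

    import org.apache.spark.sql.catalyst.parser.ParseException

    try {
      spark.sql("SELEC 1")
    } catch {
      case e: ParseException =>
        // The message includes the error class and the token the parser stopped at.
        println(e.getMessage)
    }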
*/ -case class UnclosedCommentProcessor( - command: String, tokenStream: CommonTokenStream) extends SqlBaseParserBaseListener { +case class UnclosedCommentProcessor(command: String, tokenStream: CommonTokenStream) + extends SqlBaseParserBaseListener { override def exitSingleDataType(ctx: SqlBaseParser.SingleDataTypeContext): Unit = { checkUnclosedComment(tokenStream, command) @@ -384,7 +384,8 @@ case class UnclosedCommentProcessor( checkUnclosedComment(tokenStream, command) } - override def exitSingleTableIdentifier(ctx: SqlBaseParser.SingleTableIdentifierContext): Unit = { + override def exitSingleTableIdentifier( + ctx: SqlBaseParser.SingleTableIdentifierContext): Unit = { checkUnclosedComment(tokenStream, command) } @@ -422,7 +423,8 @@ case class UnclosedCommentProcessor( // The last token is 'EOF' and the penultimate is unclosed bracketed comment val failedToken = tokenStream.get(tokenStream.size() - 2) assert(failedToken.getType() == SqlBaseParser.BRACKETED_COMMENT) - val position = Origin(Option(failedToken.getLine), Option(failedToken.getCharPositionInLine)) + val position = + Origin(Option(failedToken.getLine), Option(failedToken.getCharPositionInLine)) throw QueryParsingErrors.unclosedBracketedCommentError( command = command, start = Origin(Option(failedToken.getStartIndex)), diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala index 4c7c87504ffc4..e870a83ec4ae6 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala @@ -37,9 +37,10 @@ object TimeModes { ProcessingTime case "eventtime" => EventTime - case _ => throw new SparkIllegalArgumentException( - errorClass = "STATEFUL_PROCESSOR_UNKNOWN_TIME_MODE", - messageParameters = Map("timeMode" -> timeMode)) + case _ => + throw new SparkIllegalArgumentException( + errorClass = "STATEFUL_PROCESSOR_UNKNOWN_TIME_MODE", + messageParameters = Map("timeMode" -> timeMode)) } } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala index 0ccfc3cbc7bf2..8c54758ebc0bc 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala @@ -28,27 +28,26 @@ import org.apache.spark.sql.streaming.OutputMode private[sql] object InternalOutputModes { /** - * OutputMode in which only the new rows in the streaming DataFrame/Dataset will be - * written to the sink. This output mode can be only be used in queries that do not - * contain any aggregation. + * OutputMode in which only the new rows in the streaming DataFrame/Dataset will be written to + * the sink. This output mode can be only be used in queries that do not contain any + * aggregation. */ case object Append extends OutputMode /** - * OutputMode in which all the rows in the streaming DataFrame/Dataset will be written - * to the sink every time these is some updates. This output mode can only be used in queries - * that contain aggregations. + * OutputMode in which all the rows in the streaming DataFrame/Dataset will be written to the + * sink every time these is some updates. This output mode can only be used in queries that + * contain aggregations. 
*/ case object Complete extends OutputMode /** - * OutputMode in which only the rows in the streaming DataFrame/Dataset that were updated will be - * written to the sink every time these is some updates. If the query doesn't contain + * OutputMode in which only the rows in the streaming DataFrame/Dataset that were updated will + * be written to the sink every time these is some updates. If the query doesn't contain * aggregations, it will be equivalent to `Append` mode. */ case object Update extends OutputMode - def apply(outputMode: String): OutputMode = { outputMode.toLowerCase(Locale.ROOT) match { case "append" => @@ -57,7 +56,8 @@ private[sql] object InternalOutputModes { OutputMode.Complete case "update" => OutputMode.Update - case _ => throw new SparkIllegalArgumentException( + case _ => + throw new SparkIllegalArgumentException( errorClass = "STREAMING_OUTPUT_MODE.INVALID", messageParameters = Map("outputMode" -> outputMode)) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala index 2b3f4674539e3..d5be65a2f36cf 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala @@ -27,7 +27,8 @@ case class SQLQueryContext( originStopIndex: Option[Int], sqlText: Option[String], originObjectType: Option[String], - originObjectName: Option[String]) extends QueryContext { + originObjectName: Option[String]) + extends QueryContext { override val contextType = QueryContextType.SQL override val objectType = originObjectType.getOrElse("") @@ -37,9 +38,8 @@ case class SQLQueryContext( /** * The SQL query context of current node. For example: - * == SQL of VIEW v1 (line 1, position 25) == - * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i - * ^^^^^^^^^^^^^^^ + * ==SQL of VIEW v1 (line 1, position 25)== + * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i ^^^^^^^^^^^^^^^ */ override lazy val summary: String = { // If the query context is missing or incorrect, simply return an empty string. 
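The strings matched by InternalOutputModes.apply are the same ones users pass to DataStreamWriter.outputMode, and the lookup is case-insensitive; unknown values raise STREAMING_OUTPUT_MODE.INVALID. A sketch, noting that InternalOutputModes is private[sql] and therefore only reachable from code compiled inside org.apache.spark.sql (for example a test):

    import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
    import org.apache.spark.sql.streaming.OutputMode

    assert(InternalOutputModes("append") == OutputMode.Append())
    assert(InternalOutputModes("Update") == OutputMode.Update())   // case-insensitive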
@@ -127,8 +127,8 @@ case class SQLQueryContext( def isValid: Boolean = { sqlText.isDefined && originStartIndex.isDefined && originStopIndex.isDefined && - originStartIndex.get >= 0 && originStopIndex.get < sqlText.get.length && - originStartIndex.get <= originStopIndex.get + originStartIndex.get >= 0 && originStopIndex.get < sqlText.get.length && + originStartIndex.get <= originStopIndex.get } override def callSite: String = throw SparkUnsupportedOperationException() @@ -136,7 +136,8 @@ case class SQLQueryContext( case class DataFrameQueryContext( stackTrace: Seq[StackTraceElement], - pysparkErrorContext: Option[(String, String)]) extends QueryContext { + pysparkErrorContext: Option[(String, String)]) + extends QueryContext { override val contextType = QueryContextType.DataFrame override def objectType: String = throw SparkUnsupportedOperationException() @@ -146,19 +147,21 @@ case class DataFrameQueryContext( override val fragment: String = { pysparkErrorContext.map(_._1).getOrElse { - stackTrace.headOption.map { firstElem => - val methodName = firstElem.getMethodName - if (methodName.length > 1 && methodName(0) == '$') { - methodName.substring(1) - } else { - methodName + stackTrace.headOption + .map { firstElem => + val methodName = firstElem.getMethodName + if (methodName.length > 1 && methodName(0) == '$') { + methodName.substring(1) + } else { + methodName + } } - }.getOrElse("") + .getOrElse("") } } - override val callSite: String = pysparkErrorContext.map( - _._2).getOrElse(stackTrace.tail.mkString("\n")) + override val callSite: String = + pysparkErrorContext.map(_._2).getOrElse(stackTrace.tail.mkString("\n")) override lazy val summary: String = { val builder = new StringBuilder diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala index 3e0ebdd627b63..33fa17433abbd 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.util.ArrayImplicits._ /** - * Contexts of TreeNodes, including location, SQL text, object type and object name. - * The only supported object type is "VIEW" now. In the future, we may support SQL UDF or other - * objects which contain SQL text. + * Contexts of TreeNodes, including location, SQL text, object type and object name. The only + * supported object type is "VIEW" now. In the future, we may support SQL UDF or other objects + * which contain SQL text. */ case class Origin( line: Option[Int] = None, @@ -41,8 +41,7 @@ case class Origin( lazy val context: QueryContext = if (stackTrace.isDefined) { DataFrameQueryContext(stackTrace.get.toImmutableArraySeq, pysparkErrorContext) } else { - SQLQueryContext( - line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName) + SQLQueryContext(line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName) } def getQueryContext: Array[QueryContext] = { @@ -61,7 +60,7 @@ trait WithOrigin { } /** - * Provides a location for TreeNodes to ask about the context of their origin. For example, which + * Provides a location for TreeNodes to ask about the context of their origin. For example, which * line of code is currently being parsed. 
*/ object CurrentOrigin { @@ -75,8 +74,7 @@ object CurrentOrigin { def reset(): Unit = value.set(Origin()) def setPosition(line: Int, start: Int): Unit = { - value.set( - value.get.copy(line = Some(line), startPosition = Some(start))) + value.set(value.get.copy(line = Some(line), startPosition = Some(start))) } def withOrigin[A](o: Origin)(f: => A): A = { @@ -84,25 +82,29 @@ object CurrentOrigin { // this way withOrigin can be recursive val previous = get set(o) - val ret = try f finally { set(previous) } + val ret = + try f + finally { set(previous) } ret } /** - * This helper function captures the Spark API and its call site in the user code from the current - * stacktrace. + * This helper function captures the Spark API and its call site in the user code from the + * current stacktrace. * * As adding `withOrigin` explicitly to all Spark API definition would be a huge change, * `withOrigin` is used only at certain places where all API implementation surely pass through - * and the current stacktrace is filtered to the point where first Spark API code is invoked from - * the user code. + * and the current stacktrace is filtered to the point where first Spark API code is invoked + * from the user code. * * As there might be multiple nested `withOrigin` calls (e.g. any Spark API implementations can * invoke other APIs) only the first `withOrigin` is captured because that is closer to the user * code. * - * @param f The function that can use the origin. - * @return The result of `f`. + * @param f + * The function that can use the origin. + * @return + * The result of `f`. */ private[sql] def withOrigin[T](f: => T): T = { if (CurrentOrigin.get.stackTrace.isDefined) { @@ -114,9 +116,9 @@ object CurrentOrigin { while (i < st.length && !sparkCode(st(i))) i += 1 // Stop at the end of the first Spark code traces while (i < st.length && sparkCode(st(i))) i += 1 - val origin = Origin(stackTrace = Some(st.slice( - from = i - 1, - until = i + SqlApiConf.get.stackTracesInDataFrameContext)), + val origin = Origin( + stackTrace = + Some(st.slice(from = i - 1, until = i + SqlApiConf.get.stackTracesInDataFrameContext)), pysparkErrorContext = PySparkCurrentOrigin.get()) CurrentOrigin.withOrigin(origin)(f) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala index e47ab1978d0ed..533b09e82df13 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.errors.DataTypeErrors trait AttributeNameParser { + /** - * Used to split attribute name by dot with backticks rule. - * Backticks must appear in pairs, and the quoted string must be a complete name part, - * which means `ab..c`e.f is not allowed. - * We can use backtick only inside quoted name parts. + * Used to split attribute name by dot with backticks rule. Backticks must appear in pairs, and + * the quoted string must be a complete name part, which means `ab..c`e.f is not allowed. We can + * use backtick only inside quoted name parts. 
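A minimal sketch of the save/restore behaviour of CurrentOrigin.withOrigin, assuming nothing else has set an origin on the current thread:

    import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}

    val viewOrigin = Origin(line = Some(1), startPosition = Some(25))
    val captured = CurrentOrigin.withOrigin(viewOrigin) {
      CurrentOrigin.get.line                  // visible while the block runs
    }
    assert(captured == Some(1))
    assert(CurrentOrigin.get.line.isEmpty)    // previous origin restored afterwards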
*/ def parseAttributeName(name: String): Seq[String] = { def e = DataTypeErrors.attributeNameSyntaxError(name) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala index 640304efce4b4..9e90feeb782d6 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala @@ -23,12 +23,12 @@ import java.util.Locale * Builds a map in which keys are case insensitive. Input map can be accessed for cases where * case-sensitive information is required. The primary constructor is marked private to avoid * nested case-insensitive map creation, otherwise the keys in the original map will become - * case-insensitive in this scenario. - * Note: CaseInsensitiveMap is serializable. However, after transformation, e.g. `filterKeys()`, - * it may become not serializable. + * case-insensitive in this scenario. Note: CaseInsensitiveMap is serializable. However, after + * transformation, e.g. `filterKeys()`, it may become not serializable. */ -class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Map[String, T] - with Serializable { +class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) + extends Map[String, T] + with Serializable { val keyLowerCasedMap = originalMap.map(kv => kv.copy(_1 = kv._1.toLowerCase(Locale.ROOT))) @@ -62,4 +62,3 @@ object CaseInsensitiveMap { case _ => new CaseInsensitiveMap(params) } } - diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeJsonUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeJsonUtils.scala index e75429c58cc7b..b8ab633b2047f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeJsonUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeJsonUtils.scala @@ -25,15 +25,16 @@ import org.json4s.jackson.{JValueDeserializer, JValueSerializer} import org.apache.spark.sql.types.DataType object DataTypeJsonUtils { + /** * Jackson serializer for [[DataType]]. Internally this delegates to json4s based serialization. 
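A small sketch of the CaseInsensitiveMap lookup behaviour described above, with the original casing still available through originalMap:

    import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

    val options = CaseInsensitiveMap(Map("Path" -> "/tmp/data", "multiLine" -> "true"))
    assert(options("path") == "/tmp/data")
    assert(options.contains("MULTILINE"))
    assert(options.originalMap.keySet == Set("Path", "multiLine"))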
*/ class DataTypeJsonSerializer extends JsonSerializer[DataType] { private val delegate = new JValueSerializer override def serialize( - value: DataType, - gen: JsonGenerator, - provider: SerializerProvider): Unit = { + value: DataType, + gen: JsonGenerator, + provider: SerializerProvider): Unit = { delegate.serialize(value.jsonValue, gen, provider) } } @@ -46,8 +47,8 @@ object DataTypeJsonUtils { private val delegate = new JValueDeserializer(classOf[Any]) override def deserialize( - jsonParser: JsonParser, - deserializationContext: DeserializationContext): DataType = { + jsonParser: JsonParser, + deserializationContext: DeserializationContext): DataType = { val json = delegate.deserialize(jsonParser, deserializationContext) DataType.parseDataType(json.asInstanceOf[JValue]) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala index 34d19bb67b71a..5eada9a7be670 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala @@ -43,7 +43,8 @@ class Iso8601DateFormatter( locale: Locale, legacyFormat: LegacyDateFormats.LegacyDateFormat, isParsing: Boolean) - extends DateFormatter with DateTimeFormatterHelper { + extends DateFormatter + with DateTimeFormatterHelper { @transient private lazy val formatter = getOrCreateFormatter(pattern, locale, isParsing) @@ -62,8 +63,7 @@ class Iso8601DateFormatter( override def format(localDate: LocalDate): String = { try { localDate.format(formatter) - } catch checkFormattedDiff(toJavaDate(localDateToDays(localDate)), - (d: Date) => format(d)) + } catch checkFormattedDiff(toJavaDate(localDateToDays(localDate)), (d: Date) => format(d)) } override def format(days: Int): String = { @@ -83,19 +83,22 @@ class Iso8601DateFormatter( } /** - * The formatter for dates which doesn't require users to specify a pattern. While formatting, - * it uses the default pattern [[DateFormatter.defaultPattern]]. In parsing, it follows the CAST + * The formatter for dates which doesn't require users to specify a pattern. While formatting, it + * uses the default pattern [[DateFormatter.defaultPattern]]. In parsing, it follows the CAST * logic in conversion of strings to Catalyst's DateType. * - * @param locale The locale overrides the system locale and is used in formatting. - * @param legacyFormat Defines the formatter used for legacy dates. - * @param isParsing Whether the formatter is used for parsing (`true`) or for formatting (`false`). + * @param locale + * The locale overrides the system locale and is used in formatting. + * @param legacyFormat + * Defines the formatter used for legacy dates. + * @param isParsing + * Whether the formatter is used for parsing (`true`) or for formatting (`false`). */ class DefaultDateFormatter( locale: Locale, legacyFormat: LegacyDateFormats.LegacyDateFormat, isParsing: Boolean) - extends Iso8601DateFormatter(DateFormatter.defaultPattern, locale, legacyFormat, isParsing) { + extends Iso8601DateFormatter(DateFormatter.defaultPattern, locale, legacyFormat, isParsing) { override def parse(s: String): Int = { try { @@ -125,11 +128,13 @@ trait LegacyDateFormatter extends DateFormatter { * JVM time zone intentionally for compatibility with Spark 2.4 and earlier versions. 
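The Jackson serializer and deserializer above delegate to the same JSON representation that backs DataType.json and DataType.fromJson, so the public helpers round-trip the result:

    import org.apache.spark.sql.types._

    val schema = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))
    assert(DataType.fromJson(schema.json) == schema)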
* * Note: Using of the default JVM time zone makes the formatter compatible with the legacy - * `SparkDateTimeUtils` methods `toJavaDate` and `fromJavaDate` that are based on the default - * JVM time zone too. + * `SparkDateTimeUtils` methods `toJavaDate` and `fromJavaDate` that are based on the default JVM + * time zone too. * - * @param pattern `java.text.SimpleDateFormat` compatible pattern. - * @param locale The locale overrides the system locale and is used in parsing/formatting. + * @param pattern + * `java.text.SimpleDateFormat` compatible pattern. + * @param locale + * The locale overrides the system locale and is used in parsing/formatting. */ class LegacyFastDateFormatter(pattern: String, locale: Locale) extends LegacyDateFormatter { @transient @@ -145,14 +150,16 @@ class LegacyFastDateFormatter(pattern: String, locale: Locale) extends LegacyDat * JVM time zone intentionally for compatibility with Spark 2.4 and earlier versions. * * Note: Using of the default JVM time zone makes the formatter compatible with the legacy - * `SparkDateTimeUtils` methods `toJavaDate` and `fromJavaDate` that are based on the default - * JVM time zone too. + * `SparkDateTimeUtils` methods `toJavaDate` and `fromJavaDate` that are based on the default JVM + * time zone too. * - * @param pattern The pattern describing the date and time format. - * See - * Date and Time Patterns - * @param locale The locale whose date format symbols should be used. It overrides the system - * locale in parsing/formatting. + * @param pattern + * The pattern describing the date and time format. See Date and + * Time Patterns + * @param locale + * The locale whose date format symbols should be used. It overrides the system locale in + * parsing/formatting. */ // scalastyle:on line.size.limit class LegacySimpleDateFormatter(pattern: String, locale: Locale) extends LegacyDateFormatter { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala index 067e58893126c..71777906f868e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala @@ -40,13 +40,18 @@ trait DateTimeFormatterHelper { } private def verifyLocalDate( - accessor: TemporalAccessor, field: ChronoField, candidate: LocalDate): Unit = { + accessor: TemporalAccessor, + field: ChronoField, + candidate: LocalDate): Unit = { if (accessor.isSupported(field)) { val actual = accessor.get(field) val expected = candidate.get(field) if (actual != expected) { throw ExecutionErrors.fieldDiffersFromDerivedLocalDateError( - field, actual, expected, candidate) + field, + actual, + expected, + candidate) } } } @@ -133,7 +138,8 @@ trait DateTimeFormatterHelper { // SparkUpgradeException. On the contrary, if the legacy policy set to CORRECTED, // DateTimeParseException will address by the caller side. 
protected def checkParsedDiff[T]( - s: String, legacyParseFunc: String => T): PartialFunction[Throwable, T] = { + s: String, + legacyParseFunc: String => T): PartialFunction[Throwable, T] = { case e if needConvertToSparkUpgradeException(e) => try { legacyParseFunc(s) @@ -151,11 +157,12 @@ trait DateTimeFormatterHelper { d: T, legacyFormatFunc: T => String): PartialFunction[Throwable, String] = { case e if needConvertToSparkUpgradeException(e) => - val resultCandidate = try { - legacyFormatFunc(d) - } catch { - case _: Throwable => throw e - } + val resultCandidate = + try { + legacyFormatFunc(d) + } catch { + case _: Throwable => throw e + } throw ExecutionErrors.failToParseDateTimeInNewParserError(resultCandidate, e) } @@ -166,9 +173,11 @@ trait DateTimeFormatterHelper { * policy or follow our guide to correct their pattern. Otherwise, the original * IllegalArgumentException will be thrown. * - * @param pattern the date time pattern - * @param tryLegacyFormatter a func to capture exception, identically which forces a legacy - * datetime formatter to be initialized + * @param pattern + * the date time pattern + * @param tryLegacyFormatter + * a func to capture exception, identically which forces a legacy datetime formatter to be + * initialized */ protected def checkLegacyFormatter( pattern: String, @@ -214,8 +223,7 @@ private object DateTimeFormatterHelper { /** * Building a formatter for parsing seconds fraction with variable length */ - def createBuilderWithVarLengthSecondFraction( - pattern: String): DateTimeFormatterBuilder = { + def createBuilderWithVarLengthSecondFraction(pattern: String): DateTimeFormatterBuilder = { val builder = createBuilder() pattern.split("'").zipWithIndex.foreach { // Split string starting with the regex itself which is `'` here will produce an extra empty @@ -229,12 +237,14 @@ private object DateTimeFormatterHelper { case extractor(prefix, secondFraction, suffix) => builder.appendPattern(prefix) if (secondFraction.nonEmpty) { - builder.appendFraction(ChronoField.NANO_OF_SECOND, 1, secondFraction.length, false) + builder + .appendFraction(ChronoField.NANO_OF_SECOND, 1, secondFraction.length, false) } rest = suffix - case _ => throw new SparkIllegalArgumentException( - errorClass = "INVALID_DATETIME_PATTERN", - messageParameters = Map("pattern" -> pattern)) + case _ => + throw new SparkIllegalArgumentException( + errorClass = "INVALID_DATETIME_PATTERN.SECONDS_FRACTION", + messageParameters = Map("pattern" -> pattern)) } } case (patternPart, _) => builder.appendLiteral(patternPart) @@ -258,8 +268,10 @@ private object DateTimeFormatterHelper { val builder = createBuilder() .append(DateTimeFormatter.ISO_LOCAL_DATE) .appendLiteral(' ') - .appendValue(ChronoField.HOUR_OF_DAY, 2).appendLiteral(':') - .appendValue(ChronoField.MINUTE_OF_HOUR, 2).appendLiteral(':') + .appendValue(ChronoField.HOUR_OF_DAY, 2) + .appendLiteral(':') + .appendValue(ChronoField.MINUTE_OF_HOUR, 2) + .appendLiteral(':') .appendValue(ChronoField.SECOND_OF_MINUTE, 2) .appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true) toFormatter(builder, TimestampFormatter.defaultLocale) @@ -299,17 +311,21 @@ private object DateTimeFormatterHelper { * parsing/formatting datetime values. The pattern string is incompatible with the one defined * by SimpleDateFormat in Spark 2.4 and earlier. This function converts all incompatible pattern * for the new parser in Spark 3.0. See more details in SPARK-31030. - * @param pattern The input pattern. 
- * @return The pattern for new parser + * @param pattern + * The input pattern. + * @return + * The pattern for new parser */ def convertIncompatiblePattern(pattern: String, isParsing: Boolean): String = { - val eraDesignatorContained = pattern.split("'").zipWithIndex.exists { - case (patternPart, index) => + val eraDesignatorContained = + pattern.split("'").zipWithIndex.exists { case (patternPart, index) => // Text can be quoted using single quotes, we only check the non-quote parts. index % 2 == 0 && patternPart.contains("G") - } - (pattern + " ").split("'").zipWithIndex.map { - case (patternPart, index) => + } + (pattern + " ") + .split("'") + .zipWithIndex + .map { case (patternPart, index) => if (index % 2 == 0) { for (c <- patternPart if weekBasedLetters.contains(c)) { throw new SparkIllegalArgumentException( @@ -317,12 +333,10 @@ private object DateTimeFormatterHelper { messageParameters = Map("c" -> c.toString)) } for (c <- patternPart if unsupportedLetters.contains(c) || - (isParsing && unsupportedLettersForParsing.contains(c))) { + (isParsing && unsupportedLettersForParsing.contains(c))) { throw new SparkIllegalArgumentException( errorClass = "INVALID_DATETIME_PATTERN.ILLEGAL_CHARACTER", - messageParameters = Map( - "c" -> c.toString, - "pattern" -> pattern)) + messageParameters = Map("c" -> c.toString, "pattern" -> pattern)) } for (style <- unsupportedPatternLengths if patternPart.contains(style)) { throw new SparkIllegalArgumentException( @@ -340,6 +354,8 @@ private object DateTimeFormatterHelper { } else { patternPart } - }.mkString("'").stripSuffix(" ") + } + .mkString("'") + .stripSuffix(" ") } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala index 96c3fb81aa66f..b113bccc74dfb 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala @@ -89,10 +89,7 @@ object MathUtils { def floorMod(a: Long, b: Long): Long = withOverflow(Math.floorMod(a, b)) - def withOverflow[A]( - f: => A, - hint: String = "", - context: QueryContext = null): A = { + def withOverflow[A](f: => A, hint: String = "", context: QueryContext = null): A = { try { f } catch { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala index f9566b0e1fb13..9c043320dc812 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala @@ -34,8 +34,8 @@ import org.apache.spark.util.SparkClassUtils /** * The collection of functions for rebasing days and microseconds from/to the hybrid calendar - * (Julian + Gregorian since 1582-10-15) which is used by Spark 2.4 and earlier versions - * to/from Proleptic Gregorian calendar which is used by Spark since version 3.0. See SPARK-26651. + * (Julian + Gregorian since 1582-10-15) which is used by Spark 2.4 and earlier versions to/from + * Proleptic Gregorian calendar which is used by Spark since version 3.0. See SPARK-26651. */ object RebaseDateTime { @@ -46,20 +46,22 @@ object RebaseDateTime { } /** - * Rebases days since the epoch from an original to an target calendar, for instance, - * from a hybrid (Julian + Gregorian) to Proleptic Gregorian calendar. 
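For the MathUtils hunk above: floorMod follows Math.floorMod semantics, and withOverflow wraps the computation so that a JDK ArithmeticException is rethrown as a Spark overflow error, optionally tagged with a hint and a QueryContext. A hedged sketch:

    import org.apache.spark.sql.catalyst.util.MathUtils

    assert(MathUtils.floorMod(-7L, 3L) == 2L)
    // The next call would raise Spark's overflow error rather than a bare ArithmeticException:
    // MathUtils.withOverflow(Math.addExact(Long.MaxValue, 1L))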
+ * Rebases days since the epoch from an original to an target calendar, for instance, from a + * hybrid (Julian + Gregorian) to Proleptic Gregorian calendar. * * It finds the latest switch day which is less than the given `days`, and adds the difference - * in days associated with the switch days to the given `days`. - * The function is based on linear search which starts from the most recent switch days. - * This allows to perform less comparisons for modern dates. + * in days associated with the switch days to the given `days`. The function is based on linear + * search which starts from the most recent switch days. This allows to perform less comparisons + * for modern dates. * - * @param switches The days when difference in days between original and target calendar - * was changed. - * @param diffs The differences in days between calendars. - * @param days The number of days since the epoch 1970-01-01 to be rebased - * to the target calendar. - * @return The rebased days. + * @param switches + * The days when difference in days between original and target calendar was changed. + * @param diffs + * The differences in days between calendars. + * @param days + * The number of days since the epoch 1970-01-01 to be rebased to the target calendar. + * @return + * The rebased days. */ private def rebaseDays(switches: Array[Int], diffs: Array[Int], days: Int): Int = { var i = switches.length @@ -77,9 +79,8 @@ object RebaseDateTime { // Julian calendar). This array is not applicable for dates before the staring point. // Rebasing switch days and diffs `julianGregDiffSwitchDay` and `julianGregDiffs` // was generated by the `localRebaseJulianToGregorianDays` function. - private val julianGregDiffSwitchDay = Array( - -719164, -682945, -646420, -609895, -536845, -500320, -463795, - -390745, -354220, -317695, -244645, -208120, -171595, -141427) + private val julianGregDiffSwitchDay = Array(-719164, -682945, -646420, -609895, -536845, + -500320, -463795, -390745, -354220, -317695, -244645, -208120, -171595, -141427) final val lastSwitchJulianDay: Int = julianGregDiffSwitchDay.last @@ -88,20 +89,20 @@ object RebaseDateTime { /** * Converts the given number of days since the epoch day 1970-01-01 to a local date in Julian - * calendar, interprets the result as a local date in Proleptic Gregorian calendar, and takes the - * number of days since the epoch from the Gregorian local date. + * calendar, interprets the result as a local date in Proleptic Gregorian calendar, and takes + * the number of days since the epoch from the Gregorian local date. * * This is used to guarantee backward compatibility, as Spark 2.4 and earlier versions use - * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic Gregorian - * calendar. See SPARK-26651. + * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic + * Gregorian calendar. See SPARK-26651. * - * For example: - * Julian calendar: 1582-01-01 -> -141704 - * Proleptic Gregorian calendar: 1582-01-01 -> -141714 - * The code below converts -141704 to -141714. + * For example: Julian calendar: 1582-01-01 -> -141704 Proleptic Gregorian calendar: 1582-01-01 + * -> -141714 The code below converts -141704 to -141714. * - * @param days The number of days since the epoch in Julian calendar. It can be negative. - * @return The rebased number of days in Gregorian calendar. + * @param days + * The number of days since the epoch in Julian calendar. It can be negative. 
+ * @return + * The rebased number of days in Gregorian calendar. */ private[sql] def localRebaseJulianToGregorianDays(days: Int): Int = { val utcCal = new Calendar.Builder() @@ -111,14 +112,15 @@ object RebaseDateTime { .setTimeZone(TimeZoneUTC) .setInstant(Math.multiplyExact(days, MILLIS_PER_DAY)) .build() - val localDate = LocalDate.of( - utcCal.get(YEAR), - utcCal.get(MONTH) + 1, - // The number of days will be added later to handle non-existing - // Julian dates in Proleptic Gregorian calendar. - // For example, 1000-02-29 exists in Julian calendar because 1000 - // is a leap year but it is not a leap year in Gregorian calendar. - 1) + val localDate = LocalDate + .of( + utcCal.get(YEAR), + utcCal.get(MONTH) + 1, + // The number of days will be added later to handle non-existing + // Julian dates in Proleptic Gregorian calendar. + // For example, 1000-02-29 exists in Julian calendar because 1000 + // is a leap year but it is not a leap year in Gregorian calendar. + 1) .`with`(ChronoField.ERA, utcCal.get(ERA)) .plusDays(utcCal.get(DAY_OF_MONTH) - 1) Math.toIntExact(localDate.toEpochDay) @@ -129,8 +131,10 @@ object RebaseDateTime { * pre-calculated rebasing array to save calculation. For dates of Before Common Era, the * function falls back to the regular unoptimized version. * - * @param days The number of days since the epoch in Julian calendar. It can be negative. - * @return The rebased number of days in Gregorian calendar. + * @param days + * The number of days since the epoch in Julian calendar. It can be negative. + * @return + * The rebased number of days in Gregorian calendar. */ def rebaseJulianToGregorianDays(days: Int): Int = { if (days < julianCommonEraStartDay) { @@ -143,18 +147,17 @@ object RebaseDateTime { // The differences in days between Proleptic Gregorian and Julian dates. // The diff at the index `i` is applicable for all days in the date interval: // [gregJulianDiffSwitchDay(i), gregJulianDiffSwitchDay(i+1)) - private val gregJulianDiffs = Array( - -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) + private val gregJulianDiffs = + Array(-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) // The sorted days in Proleptic Gregorian calendar when difference in days between // Proleptic Gregorian and Julian was changed. // The starting point is the `0001-01-01` (-719162 days since the epoch in // Proleptic Gregorian calendar). This array is not applicable for dates before the staring point. // Rebasing switch days and diffs `gregJulianDiffSwitchDay` and `gregJulianDiffs` // was generated by the `localRebaseGregorianToJulianDays` function. - private val gregJulianDiffSwitchDay = Array( - -719162, -682944, -646420, -609896, -536847, -500323, -463799, -390750, - -354226, -317702, -244653, -208129, -171605, -141436, -141435, -141434, - -141433, -141432, -141431, -141430, -141429, -141428, -141427) + private val gregJulianDiffSwitchDay = Array(-719162, -682944, -646420, -609896, -536847, + -500323, -463799, -390750, -354226, -317702, -244653, -208129, -171605, -141436, -141435, + -141434, -141433, -141432, -141431, -141430, -141429, -141428, -141427) final val lastSwitchGregorianDay: Int = gregJulianDiffSwitchDay.last @@ -171,17 +174,16 @@ object RebaseDateTime { * number of days since the epoch from the Julian local date. * * This is used to guarantee backward compatibility, as Spark 2.4 and earlier versions use - * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic Gregorian - * calendar. 
See SPARK-26651. + * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic + * Gregorian calendar. See SPARK-26651. * - * For example: - * Proleptic Gregorian calendar: 1582-01-01 -> -141714 - * Julian calendar: 1582-01-01 -> -141704 - * The code below converts -141714 to -141704. + * For example: Proleptic Gregorian calendar: 1582-01-01 -> -141714 Julian calendar: 1582-01-01 + * -> -141704 The code below converts -141714 to -141704. * - * @param days The number of days since the epoch in Proleptic Gregorian calendar. - * It can be negative. - * @return The rebased number of days in Julian calendar. + * @param days + * The number of days since the epoch in Proleptic Gregorian calendar. It can be negative. + * @return + * The rebased number of days in Julian calendar. */ private[sql] def localRebaseGregorianToJulianDays(days: Int): Int = { var localDate = LocalDate.ofEpochDay(days) @@ -204,8 +206,10 @@ object RebaseDateTime { * pre-calculated rebasing array to save calculation. For dates of Before Common Era, the * function falls back to the regular unoptimized version. * - * @param days The number of days since the epoch in Gregorian calendar. It can be negative. - * @return The rebased number of days since the epoch in Julian calendar. + * @param days + * The number of days since the epoch in Gregorian calendar. It can be negative. + * @return + * The rebased number of days since the epoch in Julian calendar. */ def rebaseGregorianToJulianDays(days: Int): Int = { if (days < gregorianCommonEraStartDay) { @@ -215,10 +219,9 @@ object RebaseDateTime { } } - /** - * The class describes JSON records with microseconds rebasing info. - * Here is an example of JSON file: + * The class describes JSON records with microseconds rebasing info. Here is an example of JSON + * file: * {{{ * [ * { @@ -229,37 +232,44 @@ object RebaseDateTime { * ] * }}} * - * @param tz One of time zone ID which is expected to be acceptable by `ZoneId.of`. - * @param switches An ordered array of seconds since the epoch when the diff between - * two calendars are changed. - * @param diffs Differences in seconds associated with elements of `switches`. + * @param tz + * One of time zone ID which is expected to be acceptable by `ZoneId.of`. + * @param switches + * An ordered array of seconds since the epoch when the diff between two calendars are + * changed. + * @param diffs + * Differences in seconds associated with elements of `switches`. */ private case class JsonRebaseRecord(tz: String, switches: Array[Long], diffs: Array[Long]) /** * Rebasing info used to convert microseconds from an original to a target calendar. * - * @param switches An ordered array of microseconds since the epoch when the diff between - * two calendars are changed. - * @param diffs Differences in microseconds associated with elements of `switches`. + * @param switches + * An ordered array of microseconds since the epoch when the diff between two calendars are + * changed. + * @param diffs + * Differences in microseconds associated with elements of `switches`. */ private[sql] case class RebaseInfo(switches: Array[Long], diffs: Array[Long]) /** - * Rebases micros since the epoch from an original to an target calendar, for instance, - * from a hybrid (Julian + Gregorian) to Proleptic Gregorian calendar. + * Rebases micros since the epoch from an original to an target calendar, for instance, from a + * hybrid (Julian + Gregorian) to Proleptic Gregorian calendar. 
* * It finds the latest switch micros which is less than the given `micros`, and adds the - * difference in micros associated with the switch micros to the given `micros`. - * The function is based on linear search which starts from the most recent switch micros. - * This allows to perform less comparisons for modern timestamps. + * difference in micros associated with the switch micros to the given `micros`. The function is + * based on linear search which starts from the most recent switch micros. This allows to + * perform less comparisons for modern timestamps. * - * @param rebaseInfo The rebasing info contains an ordered micros when difference in micros - * between original and target calendar was changed, - * and differences in micros between calendars - * @param micros The number of micros since the epoch 1970-01-01T00:00:00Z to be rebased - * to the target calendar. It can be negative. - * @return The rebased micros. + * @param rebaseInfo + * The rebasing info contains an ordered micros when difference in micros between original and + * target calendar was changed, and differences in micros between calendars + * @param micros + * The number of micros since the epoch 1970-01-01T00:00:00Z to be rebased to the target + * calendar. It can be negative. + * @return + * The rebased micros. */ private def rebaseMicros(rebaseInfo: RebaseInfo, micros: Long): Long = { val switches = rebaseInfo.switches @@ -296,18 +306,19 @@ object RebaseDateTime { /** * A map of time zone IDs to its ordered time points (instants in microseconds since the epoch) - * when the difference between 2 instances associated with the same local timestamp in - * Proleptic Gregorian and the hybrid calendar was changed, and to the diff at the index `i` is - * applicable for all microseconds in the time interval: - * [gregJulianDiffSwitchMicros(i), gregJulianDiffSwitchMicros(i+1)) + * when the difference between 2 instances associated with the same local timestamp in Proleptic + * Gregorian and the hybrid calendar was changed, and to the diff at the index `i` is applicable + * for all microseconds in the time interval: [gregJulianDiffSwitchMicros(i), + * gregJulianDiffSwitchMicros(i+1)) */ private val gregJulianRebaseMap = loadRebaseRecords("gregorian-julian-rebase-micros.json") private def getLastSwitchTs(rebaseMap: AnyRefMap[String, RebaseInfo]): Long = { val latestTs = rebaseMap.values.map(_.switches.last).max - require(rebaseMap.values.forall(_.diffs.last == 0), + require( + rebaseMap.values.forall(_.diffs.last == 0), s"Differences between Julian and Gregorian calendar after ${microsToInstant(latestTs)} " + - "are expected to be zero for all available time zones.") + "are expected to be zero for all available time zones.") latestTs } // The switch time point after which all diffs between Gregorian and Julian calendars @@ -315,29 +326,30 @@ object RebaseDateTime { final val lastSwitchGregorianTs: Long = getLastSwitchTs(gregJulianRebaseMap) private final val gregorianStartTs = LocalDateTime.of(gregorianStartDate, LocalTime.MIDNIGHT) - private final val julianEndTs = LocalDateTime.of( - julianEndDate, - LocalTime.of(23, 59, 59, 999999999)) + private final val julianEndTs = + LocalDateTime.of(julianEndDate, LocalTime.of(23, 59, 59, 999999999)) /** * Converts the given number of microseconds since the epoch '1970-01-01T00:00:00Z', to a local - * date-time in Proleptic Gregorian calendar with timezone identified by `zoneId`, interprets the - * result as a local date-time in Julian calendar with the same timezone, and takes 
microseconds - * since the epoch from the Julian local date-time. + * date-time in Proleptic Gregorian calendar with timezone identified by `zoneId`, interprets + * the result as a local date-time in Julian calendar with the same timezone, and takes + * microseconds since the epoch from the Julian local date-time. * * This is used to guarantee backward compatibility, as Spark 2.4 and earlier versions use - * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic Gregorian - * calendar. See SPARK-26651. + * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic + * Gregorian calendar. See SPARK-26651. * - * For example: - * Proleptic Gregorian calendar: 1582-01-01 00:00:00.123456 -> -12244061221876544 - * Julian calendar: 1582-01-01 00:00:00.123456 -> -12243196799876544 - * The code below converts -12244061221876544 to -12243196799876544. + * For example: Proleptic Gregorian calendar: 1582-01-01 00:00:00.123456 -> -12244061221876544 + * Julian calendar: 1582-01-01 00:00:00.123456 -> -12243196799876544 The code below converts + * -12244061221876544 to -12243196799876544. * - * @param tz The time zone at which the rebasing should be performed. - * @param micros The number of microseconds since the epoch '1970-01-01T00:00:00Z' - * in Proleptic Gregorian calendar. It can be negative. - * @return The rebased microseconds since the epoch in Julian calendar. + * @param tz + * The time zone at which the rebasing should be performed. + * @param micros + * The number of microseconds since the epoch '1970-01-01T00:00:00Z' in Proleptic Gregorian + * calendar. It can be negative. + * @return + * The rebased microseconds since the epoch in Julian calendar. */ private[sql] def rebaseGregorianToJulianMicros(tz: TimeZone, micros: Long): Long = { val instant = microsToInstant(micros) @@ -380,10 +392,13 @@ object RebaseDateTime { * contain information about the given time zone `timeZoneId` or `micros` is related to Before * Common Era, the function falls back to the regular unoptimized version. * - * @param timeZoneId A string identifier of a time zone. - * @param micros The number of microseconds since the epoch '1970-01-01T00:00:00Z' - * in Proleptic Gregorian calendar. It can be negative. - * @return The rebased microseconds since the epoch in Julian calendar. + * @param timeZoneId + * A string identifier of a time zone. + * @param micros + * The number of microseconds since the epoch '1970-01-01T00:00:00Z' in Proleptic Gregorian + * calendar. It can be negative. + * @return + * The rebased microseconds since the epoch in Julian calendar. */ def rebaseGregorianToJulianMicros(timeZoneId: String, micros: Long): Long = { if (micros >= lastSwitchGregorianTs) { @@ -404,12 +419,14 @@ object RebaseDateTime { * contain information about the current JVM system time zone or `micros` is related to Before * Common Era, the function falls back to the regular unoptimized version. * - * Note: The function assumes that the input micros was derived from a local timestamp - * at the default system JVM time zone in Proleptic Gregorian calendar. + * Note: The function assumes that the input micros was derived from a local timestamp at the + * default system JVM time zone in Proleptic Gregorian calendar. * - * @param micros The number of microseconds since the epoch '1970-01-01T00:00:00Z' - * in Proleptic Gregorian calendar. It can be negative. - * @return The rebased microseconds since the epoch in Julian calendar. 
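// Illustrative sketch (not part of the patch): the call pattern for the optimized
// timestamp-rebasing overload documented above. The result depends on the chosen time zone id,
// and the zone behind the scaladoc's example numbers is not stated, so "UTC" here is only an
// illustration and no exact output is asserted.
import org.apache.spark.sql.catalyst.util.RebaseDateTime

// -12244061221876544L is the Proleptic Gregorian value from the scaladoc example
// (1582-01-01 00:00:00.123456 in some zone).
val julianMicros = RebaseDateTime.rebaseGregorianToJulianMicros("UTC", -12244061221876544L)
// The reverse rebase with the same zone id is expected to restore the original value.
val gregorianMicros = RebaseDateTime.rebaseJulianToGregorianMicros("UTC", julianMicros)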
+ * @param micros + * The number of microseconds since the epoch '1970-01-01T00:00:00Z' in Proleptic Gregorian + * calendar. It can be negative. + * @return + * The rebased microseconds since the epoch in Julian calendar. */ def rebaseGregorianToJulianMicros(micros: Long): Long = { rebaseGregorianToJulianMicros(TimeZone.getDefault.getID, micros) @@ -418,22 +435,24 @@ object RebaseDateTime { /** * Converts the given number of microseconds since the epoch '1970-01-01T00:00:00Z', to a local * date-time in Julian calendar with timezone identified by `zoneId`, interprets the result as a - * local date-time in Proleptic Gregorian calendar with the same timezone, and takes microseconds - * since the epoch from the Gregorian local date-time. + * local date-time in Proleptic Gregorian calendar with the same timezone, and takes + * microseconds since the epoch from the Gregorian local date-time. * * This is used to guarantee backward compatibility, as Spark 2.4 and earlier versions use - * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic Gregorian - * calendar. See SPARK-26651. + * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic + * Gregorian calendar. See SPARK-26651. * - * For example: - * Julian calendar: 1582-01-01 00:00:00.123456 -> -12243196799876544 - * Proleptic Gregorian calendar: 1582-01-01 00:00:00.123456 -> -12244061221876544 - * The code below converts -12243196799876544 to -12244061221876544. + * For example: Julian calendar: 1582-01-01 00:00:00.123456 -> -12243196799876544 Proleptic + * Gregorian calendar: 1582-01-01 00:00:00.123456 -> -12244061221876544 The code below converts + * -12243196799876544 to -12244061221876544. * - * @param tz The time zone at which the rebasing should be performed. - * @param micros The number of microseconds since the epoch '1970-01-01T00:00:00Z' - * in the Julian calendar. It can be negative. - * @return The rebased microseconds since the epoch in Proleptic Gregorian calendar. + * @param tz + * The time zone at which the rebasing should be performed. + * @param micros + * The number of microseconds since the epoch '1970-01-01T00:00:00Z' in the Julian calendar. + * It can be negative. + * @return + * The rebased microseconds since the epoch in Proleptic Gregorian calendar. */ private[sql] def rebaseJulianToGregorianMicros(tz: TimeZone, micros: Long): Long = { val cal = new Calendar.Builder() @@ -442,18 +461,19 @@ object RebaseDateTime { .setInstant(microsToMillis(micros)) .setTimeZone(tz) .build() - val localDateTime = LocalDateTime.of( - cal.get(YEAR), - cal.get(MONTH) + 1, - // The number of days will be added later to handle non-existing - // Julian dates in Proleptic Gregorian calendar. - // For example, 1000-02-29 exists in Julian calendar because 1000 - // is a leap year but it is not a leap year in Gregorian calendar. - 1, - cal.get(HOUR_OF_DAY), - cal.get(MINUTE), - cal.get(SECOND), - (Math.floorMod(micros, MICROS_PER_SECOND) * NANOS_PER_MICROS).toInt) + val localDateTime = LocalDateTime + .of( + cal.get(YEAR), + cal.get(MONTH) + 1, + // The number of days will be added later to handle non-existing + // Julian dates in Proleptic Gregorian calendar. + // For example, 1000-02-29 exists in Julian calendar because 1000 + // is a leap year but it is not a leap year in Gregorian calendar. 
+ 1, + cal.get(HOUR_OF_DAY), + cal.get(MINUTE), + cal.get(SECOND), + (Math.floorMod(micros, MICROS_PER_SECOND) * NANOS_PER_MICROS).toInt) .`with`(ChronoField.ERA, cal.get(ERA)) .plusDays(cal.get(DAY_OF_MONTH) - 1) val zoneId = tz.toZoneId @@ -494,10 +514,13 @@ object RebaseDateTime { * contain information about the given time zone `timeZoneId` or `micros` is related to Before * Common Era, the function falls back to the regular unoptimized version. * - * @param timeZoneId A string identifier of a time zone. - * @param micros The number of microseconds since the epoch '1970-01-01T00:00:00Z' - * in the Julian calendar. It can be negative. - * @return The rebased microseconds since the epoch in Proleptic Gregorian calendar. + * @param timeZoneId + * A string identifier of a time zone. + * @param micros + * The number of microseconds since the epoch '1970-01-01T00:00:00Z' in the Julian calendar. + * It can be negative. + * @return + * The rebased microseconds since the epoch in Proleptic Gregorian calendar. */ def rebaseJulianToGregorianMicros(timeZoneId: String, micros: Long): Long = { if (micros >= lastSwitchJulianTs) { @@ -518,12 +541,14 @@ object RebaseDateTime { * contain information about the current JVM system time zone or `micros` is related to Before * Common Era, the function falls back to the regular unoptimized version. * - * Note: The function assumes that the input micros was derived from a local timestamp - * at the default system JVM time zone in the Julian calendar. + * Note: The function assumes that the input micros was derived from a local timestamp at the + * default system JVM time zone in the Julian calendar. * - * @param micros The number of microseconds since the epoch '1970-01-01T00:00:00Z' - * in the Julian calendar. It can be negative. - * @return The rebased microseconds since the epoch in Proleptic Gregorian calendar. + * @param micros + * The number of microseconds since the epoch '1970-01-01T00:00:00Z' in the Julian calendar. + * It can be negative. + * @return + * The rebased microseconds since the epoch in Proleptic Gregorian calendar. */ def rebaseJulianToGregorianMicros(micros: Long): Long = { rebaseJulianToGregorianMicros(TimeZone.getDefault.getID, micros) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala index 498eb83566eb3..2a26c079e8d4d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala @@ -21,15 +21,18 @@ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{ArrayType, CharType, DataType, MapType, StringType, StructType, VarcharType} trait SparkCharVarcharUtils { + /** - * Returns true if the given data type is CharType/VarcharType or has nested CharType/VarcharType. + * Returns true if the given data type is CharType/VarcharType or has nested + * CharType/VarcharType. 
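// Illustrative sketch (not part of the patch): what the check documented above returns for a
// schema that nests a VARCHAR column. Calling it through CharVarcharUtils, an object that
// mixes this trait in, is an assumption of the sketch; only the method itself appears in this
// hunk.
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.types.{StringType, StructType, VarcharType}

val withVarchar = new StructType().add("name", VarcharType(10)).add("comment", StringType)
assert(CharVarcharUtils.hasCharVarchar(withVarchar))
assert(!CharVarcharUtils.hasCharVarchar(new StructType().add("comment", StringType)))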
*/ def hasCharVarchar(dt: DataType): Boolean = { dt.existsRecursively(f => f.isInstanceOf[CharType] || f.isInstanceOf[VarcharType]) } /** - * Validate the given [[DataType]] to fail if it is char or varchar types or contains nested ones + * Validate the given [[DataType]] to fail if it is char or varchar types or contains nested + * ones */ def failIfHasCharVarchar(dt: DataType): DataType = { if (!SqlApiConf.get.charVarcharAsString && hasCharVarchar(dt)) { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala index a6592ad51c65c..4e94bc6617357 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala @@ -54,8 +54,10 @@ trait SparkDateTimeUtils { /** * Converts an Java object to days. * - * @param obj Either an object of `java.sql.Date` or `java.time.LocalDate`. - * @return The number of days since 1970-01-01. + * @param obj + * Either an object of `java.sql.Date` or `java.time.LocalDate`. + * @return + * The number of days since 1970-01-01. */ def anyToDays(obj: Any): Int = obj match { case d: Date => fromJavaDate(d) @@ -65,8 +67,10 @@ trait SparkDateTimeUtils { /** * Converts an Java object to microseconds. * - * @param obj Either an object of `java.sql.Timestamp` or `java.time.{Instant,LocalDateTime}`. - * @return The number of micros since the epoch. + * @param obj + * Either an object of `java.sql.Timestamp` or `java.time.{Instant,LocalDateTime}`. + * @return + * The number of micros since the epoch. */ def anyToMicros(obj: Any): Long = obj match { case t: Timestamp => fromJavaTimestamp(t) @@ -75,8 +79,8 @@ trait SparkDateTimeUtils { } /** - * Converts the timestamp to milliseconds since epoch. In Spark timestamp values have microseconds - * precision, so this conversion is lossy. + * Converts the timestamp to milliseconds since epoch. In Spark timestamp values have + * microseconds precision, so this conversion is lossy. */ def microsToMillis(micros: Long): Long = { // When the timestamp is negative i.e before 1970, we need to adjust the milliseconds portion. @@ -97,8 +101,8 @@ trait SparkDateTimeUtils { private val MIN_SECONDS = Math.floorDiv(Long.MinValue, MICROS_PER_SECOND) /** - * Obtains an instance of `java.time.Instant` using microseconds from - * the epoch of 1970-01-01 00:00:00Z. + * Obtains an instance of `java.time.Instant` using microseconds from the epoch of 1970-01-01 + * 00:00:00Z. */ def microsToInstant(micros: Long): Instant = { val secs = Math.floorDiv(micros, MICROS_PER_SECOND) @@ -110,8 +114,8 @@ trait SparkDateTimeUtils { /** * Gets the number of microseconds since the epoch of 1970-01-01 00:00:00Z from the given - * instance of `java.time.Instant`. The epoch microsecond count is a simple incrementing count of - * microseconds where microsecond 0 is 1970-01-01 00:00:00Z. + * instance of `java.time.Instant`. The epoch microsecond count is a simple incrementing count + * of microseconds where microsecond 0 is 1970-01-01 00:00:00Z. */ def instantToMicros(instant: Instant): Long = { val secs = instant.getEpochSecond @@ -127,8 +131,8 @@ trait SparkDateTimeUtils { /** * Converts the timestamp `micros` from one timezone to another. * - * Time-zone rules, such as daylight savings, mean that not every local date-time - * is valid for the `toZone` time zone, thus the local date-time may be adjusted. 
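// Illustrative sketch (not part of the patch): the microsecond <-> Instant helpers defined
// earlier in this trait. Calling them through DateTimeUtils, which mixes the trait in, is an
// assumption of the sketch.
import java.time.Instant
import org.apache.spark.sql.catalyst.util.DateTimeUtils

val micros = DateTimeUtils.instantToMicros(Instant.parse("1970-01-01T00:00:01.000001Z"))
assert(micros == 1000001L) // one second and one microsecond after the epoch
assert(DateTimeUtils.microsToInstant(micros) == Instant.parse("1970-01-01T00:00:01.000001Z"))
assert(DateTimeUtils.microsToMillis(-1L) == -1L) // floor division, so pre-1970 values round down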
+ * Time-zone rules, such as daylight savings, mean that not every local date-time is valid for + * the `toZone` time zone, thus the local date-time may be adjusted. */ def convertTz(micros: Long, fromZone: ZoneId, toZone: ZoneId): Long = { val rebasedDateTime = getLocalDateTime(micros, toZone).atZone(fromZone) @@ -160,14 +164,16 @@ trait SparkDateTimeUtils { def daysToLocalDate(days: Int): LocalDate = LocalDate.ofEpochDay(days) /** - * Converts microseconds since 1970-01-01 00:00:00Z to days since 1970-01-01 at the given zone ID. + * Converts microseconds since 1970-01-01 00:00:00Z to days since 1970-01-01 at the given zone + * ID. */ def microsToDays(micros: Long, zoneId: ZoneId): Int = { localDateToDays(getLocalDateTime(micros, zoneId).toLocalDate) } /** - * Converts days since 1970-01-01 at the given zone ID to microseconds since 1970-01-01 00:00:00Z. + * Converts days since 1970-01-01 at the given zone ID to microseconds since 1970-01-01 + * 00:00:00Z. */ def daysToMicros(days: Int, zoneId: ZoneId): Long = { val instant = daysToLocalDate(days).atStartOfDay(zoneId).toInstant @@ -175,20 +181,22 @@ trait SparkDateTimeUtils { } /** - * Converts a local date at the default JVM time zone to the number of days since 1970-01-01 - * in the hybrid calendar (Julian + Gregorian) by discarding the time part. The resulted days are + * Converts a local date at the default JVM time zone to the number of days since 1970-01-01 in + * the hybrid calendar (Julian + Gregorian) by discarding the time part. The resulted days are * rebased from the hybrid to Proleptic Gregorian calendar. The days rebasing is performed via - * UTC time zone for simplicity because the difference between two calendars is the same in - * any given time zone and UTC time zone. + * UTC time zone for simplicity because the difference between two calendars is the same in any + * given time zone and UTC time zone. * - * Note: The date is shifted by the offset of the default JVM time zone for backward compatibility - * with Spark 2.4 and earlier versions. The goal of the shift is to get a local date derived - * from the number of days that has the same date fields (year, month, day) as the original - * `date` at the default JVM time zone. + * Note: The date is shifted by the offset of the default JVM time zone for backward + * compatibility with Spark 2.4 and earlier versions. The goal of the shift is to get a local + * date derived from the number of days that has the same date fields (year, month, day) as the + * original `date` at the default JVM time zone. * - * @param date It represents a specific instant in time based on the hybrid calendar which - * combines Julian and Gregorian calendars. - * @return The number of days since the epoch in Proleptic Gregorian calendar. + * @param date + * It represents a specific instant in time based on the hybrid calendar which combines Julian + * and Gregorian calendars. + * @return + * The number of days since the epoch in Proleptic Gregorian calendar. */ def fromJavaDate(date: Date): Int = { val millisUtc = date.getTime @@ -207,18 +215,20 @@ trait SparkDateTimeUtils { } /** - * Converts days since the epoch 1970-01-01 in Proleptic Gregorian calendar to a local date - * at the default JVM time zone in the hybrid calendar (Julian + Gregorian). It rebases the given + * Converts days since the epoch 1970-01-01 in Proleptic Gregorian calendar to a local date at + * the default JVM time zone in the hybrid calendar (Julian + Gregorian). 
It rebases the given * days from Proleptic Gregorian to the hybrid calendar at UTC time zone for simplicity because * the difference between two calendars doesn't depend on any time zone. The result is shifted - * by the time zone offset in wall clock to have the same date fields (year, month, day) - * at the default JVM time zone as the input `daysSinceEpoch` in Proleptic Gregorian calendar. + * by the time zone offset in wall clock to have the same date fields (year, month, day) at the + * default JVM time zone as the input `daysSinceEpoch` in Proleptic Gregorian calendar. * - * Note: The date is shifted by the offset of the default JVM time zone for backward compatibility - * with Spark 2.4 and earlier versions. + * Note: The date is shifted by the offset of the default JVM time zone for backward + * compatibility with Spark 2.4 and earlier versions. * - * @param days The number of days since 1970-01-01 in Proleptic Gregorian calendar. - * @return A local date in the hybrid calendar as `java.sql.Date` from number of days since epoch. + * @param days + * The number of days since 1970-01-01 in Proleptic Gregorian calendar. + * @return + * A local date in the hybrid calendar as `java.sql.Date` from number of days since epoch. */ def toJavaDate(days: Int): Date = { val rebasedDays = rebaseGregorianToJulianDays(days) @@ -233,20 +243,22 @@ trait SparkDateTimeUtils { } /** - * Converts microseconds since the epoch to an instance of `java.sql.Timestamp` - * via creating a local timestamp at the system time zone in Proleptic Gregorian - * calendar, extracting date and time fields like `year` and `hours`, and forming - * new timestamp in the hybrid calendar from the extracted fields. + * Converts microseconds since the epoch to an instance of `java.sql.Timestamp` via creating a + * local timestamp at the system time zone in Proleptic Gregorian calendar, extracting date and + * time fields like `year` and `hours`, and forming new timestamp in the hybrid calendar from + * the extracted fields. * - * The conversion is based on the JVM system time zone because the `java.sql.Timestamp` - * uses the time zone internally. + * The conversion is based on the JVM system time zone because the `java.sql.Timestamp` uses the + * time zone internally. * * The method performs the conversion via local timestamp fields to have the same date-time - * representation as `year`, `month`, `day`, ..., `seconds` in the original calendar - * and in the target calendar. + * representation as `year`, `month`, `day`, ..., `seconds` in the original calendar and in the + * target calendar. * - * @param micros The number of microseconds since 1970-01-01T00:00:00.000000Z. - * @return A `java.sql.Timestamp` from number of micros since epoch. + * @param micros + * The number of microseconds since 1970-01-01T00:00:00.000000Z. + * @return + * A `java.sql.Timestamp` from number of micros since epoch. */ def toJavaTimestamp(micros: Long): Timestamp = toJavaTimestampNoRebase(rebaseGregorianToJulianMicros(micros)) @@ -257,8 +269,10 @@ trait SparkDateTimeUtils { /** * Converts microseconds since the epoch to an instance of `java.sql.Timestamp`. * - * @param micros The number of microseconds since 1970-01-01T00:00:00.000000Z. - * @return A `java.sql.Timestamp` from number of micros since epoch. + * @param micros + * The number of microseconds since 1970-01-01T00:00:00.000000Z. + * @return + * A `java.sql.Timestamp` from number of micros since epoch. 
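// Illustrative sketch (not part of the patch): the java.sql.Date / java.sql.Timestamp bridges
// documented above, restricted to post-1582 values where the hybrid and Proleptic Gregorian
// calendars agree. Going through DateTimeUtils is again an assumption of the sketch.
import java.sql.Date
import org.apache.spark.sql.catalyst.util.DateTimeUtils

assert(DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01")) == 0) // the epoch day
assert(DateTimeUtils.toJavaDate(0) == Date.valueOf("1970-01-01"))
val ts = DateTimeUtils.toJavaTimestamp(1000000L) // 1970-01-01 00:00:01 UTC
assert(DateTimeUtils.fromJavaTimestamp(ts) == 1000000L)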
*/ def toJavaTimestampNoRebase(micros: Long): Timestamp = { val seconds = Math.floorDiv(micros, MICROS_PER_SECOND) @@ -270,22 +284,22 @@ trait SparkDateTimeUtils { /** * Converts an instance of `java.sql.Timestamp` to the number of microseconds since - * 1970-01-01T00:00:00.000000Z. It extracts date-time fields from the input, builds - * a local timestamp in Proleptic Gregorian calendar from the fields, and binds - * the timestamp to the system time zone. The resulted instant is converted to - * microseconds since the epoch. + * 1970-01-01T00:00:00.000000Z. It extracts date-time fields from the input, builds a local + * timestamp in Proleptic Gregorian calendar from the fields, and binds the timestamp to the + * system time zone. The resulted instant is converted to microseconds since the epoch. * - * The conversion is performed via the system time zone because it is used internally - * in `java.sql.Timestamp` while extracting date-time fields. + * The conversion is performed via the system time zone because it is used internally in + * `java.sql.Timestamp` while extracting date-time fields. * * The goal of the function is to have the same local date-time in the original calendar - * - the hybrid calendar (Julian + Gregorian) and in the target calendar which is - * Proleptic Gregorian calendar, see SPARK-26651. + * - the hybrid calendar (Julian + Gregorian) and in the target calendar which is Proleptic + * Gregorian calendar, see SPARK-26651. * - * @param t It represents a specific instant in time based on - * the hybrid calendar which combines Julian and - * Gregorian calendars. - * @return The number of micros since epoch from `java.sql.Timestamp`. + * @param t + * It represents a specific instant in time based on the hybrid calendar which combines Julian + * and Gregorian calendars. + * @return + * The number of micros since epoch from `java.sql.Timestamp`. */ def fromJavaTimestamp(t: Timestamp): Long = rebaseJulianToGregorianMicros(fromJavaTimestampNoRebase(t)) @@ -297,30 +311,27 @@ trait SparkDateTimeUtils { * Converts an instance of `java.sql.Timestamp` to the number of microseconds since * 1970-01-01T00:00:00.000000Z. * - * @param t an instance of `java.sql.Timestamp`. - * @return The number of micros since epoch from `java.sql.Timestamp`. + * @param t + * an instance of `java.sql.Timestamp`. + * @return + * The number of micros since epoch from `java.sql.Timestamp`. */ def fromJavaTimestampNoRebase(t: Timestamp): Long = millisToMicros(t.getTime) + (t.getNanos / NANOS_PER_MICROS) % MICROS_PER_MILLIS /** - * Trims and parses a given UTF8 date string to a corresponding [[Int]] value. - * The return type is [[Option]] in order to distinguish between 0 and null. The following - * formats are allowed: + * Trims and parses a given UTF8 date string to a corresponding [[Int]] value. The return type + * is [[Option]] in order to distinguish between 0 and null. The following formats are allowed: * - * `[+-]yyyy*` - * `[+-]yyyy*-[m]m` - * `[+-]yyyy*-[m]m-[d]d` - * `[+-]yyyy*-[m]m-[d]d ` - * `[+-]yyyy*-[m]m-[d]d *` - * `[+-]yyyy*-[m]m-[d]dT*` + * `[+-]yyyy*` `[+-]yyyy*-[m]m` `[+-]yyyy*-[m]m-[d]d` `[+-]yyyy*-[m]m-[d]d ` + * `[+-]yyyy*-[m]m-[d]d *` `[+-]yyyy*-[m]m-[d]dT*` */ def stringToDate(s: UTF8String): Option[Int] = { def isValidDigits(segment: Int, digits: Int): Boolean = { // An integer is able to represent a date within [+-]5 million years. 
val maxDigitsYear = 7 (segment == 0 && digits >= 4 && digits <= maxDigitsYear) || - (segment != 0 && digits > 0 && digits <= 2) + (segment != 0 && digits > 0 && digits <= 2) } if (s == null) { return None @@ -380,24 +391,18 @@ trait SparkDateTimeUtils { } } - def stringToDateAnsi( - s: UTF8String, - context: QueryContext = null): Int = { + def stringToDateAnsi(s: UTF8String, context: QueryContext = null): Int = { stringToDate(s).getOrElse { throw ExecutionErrors.invalidInputInCastToDatetimeError(s, DateType, context) } } /** - * Trims and parses a given UTF8 timestamp string to the corresponding timestamp segments, - * time zone id and whether it is just time without a date. - * value. The return type is [[Option]] in order to distinguish between 0L and null. The following - * formats are allowed: + * Trims and parses a given UTF8 timestamp string to the corresponding timestamp segments, time + * zone id and whether it is just time without a date. value. The return type is [[Option]] in + * order to distinguish between 0L and null. The following formats are allowed: * - * `[+-]yyyy*` - * `[+-]yyyy*-[m]m` - * `[+-]yyyy*-[m]m-[d]d` - * `[+-]yyyy*-[m]m-[d]d ` + * `[+-]yyyy*` `[+-]yyyy*-[m]m` `[+-]yyyy*-[m]m-[d]d` `[+-]yyyy*-[m]m-[d]d ` * `[+-]yyyy*-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]` * `[+-]yyyy*-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]` * `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]` @@ -407,16 +412,17 @@ trait SparkDateTimeUtils { * - Z - Zulu time zone UTC+0 * - +|-[h]h:[m]m * - A short id, see https://docs.oracle.com/javase/8/docs/api/java/time/ZoneId.html#SHORT_IDS - * - An id with one of the prefixes UTC+, UTC-, GMT+, GMT-, UT+ or UT-, - * and a suffix in the formats: + * - An id with one of the prefixes UTC+, UTC-, GMT+, GMT-, UT+ or UT-, and a suffix in the + * formats: * - +|-h[h] * - +|-hh[:]mm * - +|-hh:mm:ss * - +|-hhmmss - * - Region-based zone IDs in the form `area/city`, such as `Europe/Paris` + * - Region-based zone IDs in the form `area/city`, such as `Europe/Paris` * - * @return timestamp segments, time zone id and whether the input is just time without a date. If - * the input string can't be parsed as timestamp, the result timestamp segments are empty. + * @return + * timestamp segments, time zone id and whether the input is just time without a date. If the + * input string can't be parsed as timestamp, the result timestamp segments are empty. */ def parseTimestampString(s: UTF8String): (Array[Int], Option[ZoneId], Boolean) = { def isValidDigits(segment: Int, digits: Int): Boolean = { @@ -424,9 +430,9 @@ trait SparkDateTimeUtils { val maxDigitsYear = 6 // For the nanosecond part, more than 6 digits is allowed, but will be truncated. 
segment == 6 || (segment == 0 && digits >= 4 && digits <= maxDigitsYear) || - // For the zoneId segment(7), it's could be zero digits when it's a region-based zone ID - (segment == 7 && digits <= 2) || - (segment != 0 && segment != 6 && segment != 7 && digits > 0 && digits <= 2) + // For the zoneId segment(7), it's could be zero digits when it's a region-based zone ID + (segment == 7 && digits <= 2) || + (segment != 0 && segment != 6 && segment != 7 && digits > 0 && digits <= 2) } if (s == null) { return (Array.empty, None, false) @@ -523,7 +529,7 @@ trait SparkDateTimeUtils { tz = Some(new String(bytes, j, strEndTrimmed - j)) j = strEndTrimmed - 1 } - if (i == 6 && b != '.') { + if (i == 6 && b != '.') { i += 1 } } else { @@ -612,11 +618,11 @@ trait SparkDateTimeUtils { * * If the input string contains a component associated with time zone, the method will return * `None` if `allowTimeZone` is set to `false`. If `allowTimeZone` is set to `true`, the method - * will simply discard the time zone component. Enable the check to detect situations like parsing - * a timestamp with time zone as TimestampNTZType. + * will simply discard the time zone component. Enable the check to detect situations like + * parsing a timestamp with time zone as TimestampNTZType. * - * The return type is [[Option]] in order to distinguish between 0L and null. Please - * refer to `parseTimestampString` for the allowed formats. + * The return type is [[Option]] in order to distinguish between 0L and null. Please refer to + * `parseTimestampString` for the allowed formats. */ def stringToTimestampWithoutTimeZone(s: UTF8String, allowTimeZone: Boolean): Option[Long] = { try { @@ -637,10 +643,13 @@ trait SparkDateTimeUtils { } /** - * Returns the index of the first non-whitespace and non-ISO control character in the byte array. + * Returns the index of the first non-whitespace and non-ISO control character in the byte + * array. * - * @param bytes The byte array to be processed. - * @return The start index after trimming. + * @param bytes + * The byte array to be processed. + * @return + * The start index after trimming. */ @inline private def getTrimmedStart(bytes: Array[Byte]) = { var start = 0 @@ -655,9 +664,12 @@ trait SparkDateTimeUtils { /** * Returns the index of the last non-whitespace and non-ISO control character in the byte array. * - * @param start The starting index for the search. - * @param bytes The byte array to be processed. - * @return The end index after trimming. + * @param start + * The starting index for the search. + * @param bytes + * The byte array to be processed. + * @return + * The end index after trimming. */ @inline private def getTrimmedEnd(start: Int, bytes: Array[Byte]) = { var end = bytes.length - 1 diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala index 5e236187a4a0f..b8387b78ae3e2 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala @@ -38,17 +38,16 @@ trait SparkIntervalUtils { private final val minDurationSeconds = Math.floorDiv(Long.MinValue, MICROS_PER_SECOND) /** - * Converts this duration to the total length in microseconds. - *
<p>
    - * If this duration is too large to fit in a [[Long]] microseconds, then an - * exception is thrown. - *
<p>
    - * If this duration has greater than microsecond precision, then the conversion - * will drop any excess precision information as though the amount in nanoseconds - * was subject to integer division by one thousand. + * Converts this duration to the total length in microseconds.
<p>
    If this duration is too large + * to fit in a [[Long]] microseconds, then an exception is thrown.
<p>
    If this duration has + * greater than microsecond precision, then the conversion will drop any excess precision + * information as though the amount in nanoseconds was subject to integer division by one + * thousand. * - * @return The total length of the duration in microseconds - * @throws ArithmeticException If numeric overflow occurs + * @return + * The total length of the duration in microseconds + * @throws ArithmeticException + * If numeric overflow occurs */ def durationToMicros(duration: Duration): Long = { durationToMicros(duration, DT.SECOND) @@ -59,7 +58,8 @@ trait SparkIntervalUtils { val micros = if (seconds == minDurationSeconds) { val microsInSeconds = (minDurationSeconds + 1) * MICROS_PER_SECOND val nanoAdjustment = duration.getNano - assert(0 <= nanoAdjustment && nanoAdjustment < NANOS_PER_SECOND, + assert( + 0 <= nanoAdjustment && nanoAdjustment < NANOS_PER_SECOND, "Duration.getNano() must return the adjustment to the seconds field " + "in the range from 0 to 999999999 nanoseconds, inclusive.") Math.addExact(microsInSeconds, (nanoAdjustment - NANOS_PER_SECOND) / NANOS_PER_MICROS) @@ -77,14 +77,13 @@ trait SparkIntervalUtils { } /** - * Gets the total number of months in this period. - *
<p>
    - * This returns the total number of months in the period by multiplying the - * number of years by 12 and adding the number of months. - *
<p>
    + * Gets the total number of months in this period.
<p>
    This returns the total number of months + * in the period by multiplying the number of years by 12 and adding the number of months.
<p>
    * - * @return The total number of months in the period, may be negative - * @throws ArithmeticException If numeric overflow occurs + * @return + * The total number of months in the period, may be negative + * @throws ArithmeticException + * If numeric overflow occurs */ def periodToMonths(period: Period): Int = { periodToMonths(period, YM.MONTH) @@ -103,39 +102,41 @@ trait SparkIntervalUtils { /** * Obtains a [[Duration]] representing a number of microseconds. * - * @param micros The number of microseconds, positive or negative - * @return A [[Duration]], not null + * @param micros + * The number of microseconds, positive or negative + * @return + * A [[Duration]], not null */ def microsToDuration(micros: Long): Duration = Duration.of(micros, ChronoUnit.MICROS) /** - * Obtains a [[Period]] representing a number of months. The days unit will be zero, and the years - * and months units will be normalized. + * Obtains a [[Period]] representing a number of months. The days unit will be zero, and the + * years and months units will be normalized. * - *
<p>
    - * The months unit is adjusted to have an absolute value < 12, with the years unit being adjusted - * to compensate. For example, the method returns "2 years and 3 months" for the 27 input months. - *
<p>
    - * The sign of the years and months units will be the same after normalization. - * For example, -13 months will be converted to "-1 year and -1 month". + *
<p>
    The months unit is adjusted to have an absolute value < 12, with the years unit being + * adjusted to compensate. For example, the method returns "2 years and 3 months" for the 27 + * input months.
<p>
    The sign of the years and months units will be the same after + * normalization. For example, -13 months will be converted to "-1 year and -1 month". * - * @param months The number of months, positive or negative - * @return The period of months, not null + * @param months + * The number of months, positive or negative + * @return + * The period of months, not null */ def monthsToPeriod(months: Int): Period = Period.ofMonths(months).normalized() /** * Converts a string to [[CalendarInterval]] case-insensitively. * - * @throws IllegalArgumentException if the input string is not in valid interval format. + * @throws IllegalArgumentException + * if the input string is not in valid interval format. */ def stringToInterval(input: UTF8String): CalendarInterval = { import ParseState._ if (input == null) { throw new SparkIllegalArgumentException( errorClass = "INVALID_INTERVAL_FORMAT.INPUT_IS_NULL", - messageParameters = Map( - "input" -> "null")) + messageParameters = Map("input" -> "null")) } // scalastyle:off caselocale .toLowerCase val s = input.trimAll().toLowerCase @@ -144,8 +145,7 @@ trait SparkIntervalUtils { if (bytes.isEmpty) { throw new SparkIllegalArgumentException( errorClass = "INVALID_INTERVAL_FORMAT.INPUT_IS_EMPTY", - messageParameters = Map( - "input" -> input.toString)) + messageParameters = Map("input" -> input.toString)) } var state = PREFIX var i = 0 @@ -182,14 +182,11 @@ trait SparkIntervalUtils { if (s.numBytes() == intervalStr.numBytes()) { throw new SparkIllegalArgumentException( errorClass = "INVALID_INTERVAL_FORMAT.INPUT_IS_EMPTY", - messageParameters = Map( - "input" -> input.toString)) + messageParameters = Map("input" -> input.toString)) } else if (!Character.isWhitespace(bytes(i + intervalStr.numBytes()))) { throw new SparkIllegalArgumentException( errorClass = "INVALID_INTERVAL_FORMAT.INVALID_PREFIX", - messageParameters = Map( - "input" -> input.toString, - "prefix" -> currentWord)) + messageParameters = Map("input" -> input.toString, "prefix" -> currentWord)) } else { i += intervalStr.numBytes() + 1 } @@ -224,11 +221,10 @@ trait SparkIntervalUtils { pointPrefixed = true i += 1 state = VALUE_FRACTIONAL_PART - case _ => throw new SparkIllegalArgumentException( - errorClass = "INVALID_INTERVAL_FORMAT.UNRECOGNIZED_NUMBER", - messageParameters = Map( - "input" -> input.toString, - "number" -> currentWord)) + case _ => + throw new SparkIllegalArgumentException( + errorClass = "INVALID_INTERVAL_FORMAT.UNRECOGNIZED_NUMBER", + messageParameters = Map("input" -> input.toString, "number" -> currentWord)) } case TRIM_BEFORE_VALUE => trimToNextState(b, VALUE) case VALUE => @@ -237,20 +233,19 @@ trait SparkIntervalUtils { try { currentValue = Math.addExact(Math.multiplyExact(10, currentValue), (b - '0')) } catch { - case e: ArithmeticException => throw new SparkIllegalArgumentException( - errorClass = "INVALID_INTERVAL_FORMAT.ARITHMETIC_EXCEPTION", - messageParameters = Map( - "input" -> input.toString)) + case e: ArithmeticException => + throw new SparkIllegalArgumentException( + errorClass = "INVALID_INTERVAL_FORMAT.ARITHMETIC_EXCEPTION", + messageParameters = Map("input" -> input.toString)) } case _ if Character.isWhitespace(b) => state = TRIM_BEFORE_UNIT case '.' 
=> fractionScale = initialFractionScale state = VALUE_FRACTIONAL_PART - case _ => throw new SparkIllegalArgumentException( - errorClass = "INVALID_INTERVAL_FORMAT.INVALID_VALUE", - messageParameters = Map( - "input" -> input.toString, - "value" -> currentWord)) + case _ => + throw new SparkIllegalArgumentException( + errorClass = "INVALID_INTERVAL_FORMAT.INVALID_VALUE", + messageParameters = Map("input" -> input.toString, "value" -> currentWord)) } i += 1 case VALUE_FRACTIONAL_PART => @@ -264,15 +259,11 @@ trait SparkIntervalUtils { } else if ('0' <= b && b <= '9') { throw new SparkIllegalArgumentException( errorClass = "INVALID_INTERVAL_FORMAT.INVALID_PRECISION", - messageParameters = Map( - "input" -> input.toString, - "value" -> currentWord)) + messageParameters = Map("input" -> input.toString, "value" -> currentWord)) } else { throw new SparkIllegalArgumentException( errorClass = "INVALID_INTERVAL_FORMAT.INVALID_VALUE", - messageParameters = Map( - "input" -> input.toString, - "value" -> currentWord)) + messageParameters = Map("input" -> input.toString, "value" -> currentWord)) } i += 1 case TRIM_BEFORE_UNIT => trimToNextState(b, UNIT_BEGIN) @@ -281,9 +272,7 @@ trait SparkIntervalUtils { if (b != 's' && fractionScale >= 0) { throw new SparkIllegalArgumentException( errorClass = "INVALID_INTERVAL_FORMAT.INVALID_FRACTION", - messageParameters = Map( - "input" -> input.toString, - "unit" -> currentWord)) + messageParameters = Map("input" -> input.toString, "unit" -> currentWord)) } if (isNegative) { currentValue = -currentValue @@ -328,44 +317,38 @@ trait SparkIntervalUtils { } else { throw new SparkIllegalArgumentException( errorClass = "INVALID_INTERVAL_FORMAT.INVALID_UNIT", - messageParameters = Map( - "input" -> input.toString, - "unit" -> currentWord)) + messageParameters = Map("input" -> input.toString, "unit" -> currentWord)) } - case _ => throw new SparkIllegalArgumentException( - errorClass = "INVALID_INTERVAL_FORMAT.INVALID_UNIT", - messageParameters = Map( - "input" -> input.toString, - "unit" -> currentWord)) + case _ => + throw new SparkIllegalArgumentException( + errorClass = "INVALID_INTERVAL_FORMAT.INVALID_UNIT", + messageParameters = Map("input" -> input.toString, "unit" -> currentWord)) } } catch { - case e: ArithmeticException => throw new SparkIllegalArgumentException( - errorClass = "INVALID_INTERVAL_FORMAT.ARITHMETIC_EXCEPTION", - messageParameters = Map( - "input" -> input.toString)) + case e: ArithmeticException => + throw new SparkIllegalArgumentException( + errorClass = "INVALID_INTERVAL_FORMAT.ARITHMETIC_EXCEPTION", + messageParameters = Map("input" -> input.toString)) } state = UNIT_SUFFIX case UNIT_SUFFIX => b match { case 's' => state = UNIT_END case _ if Character.isWhitespace(b) => state = TRIM_BEFORE_SIGN - case _ => throw new SparkIllegalArgumentException( - errorClass = "INVALID_INTERVAL_FORMAT.INVALID_UNIT", - messageParameters = Map( - "input" -> input.toString, - "unit" -> currentWord)) + case _ => + throw new SparkIllegalArgumentException( + errorClass = "INVALID_INTERVAL_FORMAT.INVALID_UNIT", + messageParameters = Map("input" -> input.toString, "unit" -> currentWord)) } i += 1 case UNIT_END => - if (Character.isWhitespace(b) ) { + if (Character.isWhitespace(b)) { i += 1 state = TRIM_BEFORE_SIGN } else { throw new SparkIllegalArgumentException( errorClass = "INVALID_INTERVAL_FORMAT.INVALID_UNIT", - messageParameters = Map( - "input" -> input.toString, - "unit" -> currentWord)) + messageParameters = Map("input" -> input.toString, "unit" -> 
currentWord)) } } } @@ -373,36 +356,37 @@ trait SparkIntervalUtils { val result = state match { case UNIT_SUFFIX | UNIT_END | TRIM_BEFORE_SIGN => new CalendarInterval(months, days, microseconds) - case TRIM_BEFORE_VALUE => throw new SparkIllegalArgumentException( - errorClass = "INVALID_INTERVAL_FORMAT.MISSING_NUMBER", - messageParameters = Map( - "input" -> input.toString, - "word" -> currentWord)) + case TRIM_BEFORE_VALUE => + throw new SparkIllegalArgumentException( + errorClass = "INVALID_INTERVAL_FORMAT.MISSING_NUMBER", + messageParameters = Map("input" -> input.toString, "word" -> currentWord)) case VALUE | VALUE_FRACTIONAL_PART => throw new SparkIllegalArgumentException( errorClass = "INVALID_INTERVAL_FORMAT.MISSING_UNIT", - messageParameters = Map( - "input" -> input.toString, - "word" -> currentWord)) - case _ => throw new SparkIllegalArgumentException( - errorClass = "INVALID_INTERVAL_FORMAT.UNKNOWN_PARSING_ERROR", - messageParameters = Map( - "input" -> input.toString, - "word" -> currentWord)) + messageParameters = Map("input" -> input.toString, "word" -> currentWord)) + case _ => + throw new SparkIllegalArgumentException( + errorClass = "INVALID_INTERVAL_FORMAT.UNKNOWN_PARSING_ERROR", + messageParameters = Map("input" -> input.toString, "word" -> currentWord)) } result } /** - * Converts an year-month interval as a number of months to its textual representation - * which conforms to the ANSI SQL standard. + * Converts an year-month interval as a number of months to its textual representation which + * conforms to the ANSI SQL standard. * - * @param months The number of months, positive or negative - * @param style The style of textual representation of the interval - * @param startField The start field (YEAR or MONTH) which the interval comprises of. - * @param endField The end field (YEAR or MONTH) which the interval comprises of. - * @return Year-month interval string + * @param months + * The number of months, positive or negative + * @param style + * The style of textual representation of the interval + * @param startField + * The start field (YEAR or MONTH) which the interval comprises of. + * @param endField + * The end field (YEAR or MONTH) which the interval comprises of. + * @return + * Year-month interval string */ def toYearMonthIntervalString( months: Int, @@ -434,14 +418,19 @@ trait SparkIntervalUtils { } /** - * Converts a day-time interval as a number of microseconds to its textual representation - * which conforms to the ANSI SQL standard. + * Converts a day-time interval as a number of microseconds to its textual representation which + * conforms to the ANSI SQL standard. * - * @param micros The number of microseconds, positive or negative - * @param style The style of textual representation of the interval - * @param startField The start field (DAY, HOUR, MINUTE, SECOND) which the interval comprises of. - * @param endField The end field (DAY, HOUR, MINUTE, SECOND) which the interval comprises of. - * @return Day-time interval string + * @param micros + * The number of microseconds, positive or negative + * @param style + * The style of textual representation of the interval + * @param startField + * The start field (DAY, HOUR, MINUTE, SECOND) which the interval comprises of. + * @param endField + * The end field (DAY, HOUR, MINUTE, SECOND) which the interval comprises of. 
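// Illustrative sketch (not part of the patch): the interval conversion helpers reformatted in
// this file, using the worked examples from their scaladoc (27 months is "2 years and
// 3 months", -13 months is "-1 year and -1 month"). Access through IntervalUtils, an object
// that mixes this trait in, is an assumption of the sketch.
import java.time.{Duration, Period}
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.unsafe.types.UTF8String

assert(IntervalUtils.durationToMicros(Duration.ofSeconds(1, 1000)) == 1000001L)
assert(IntervalUtils.microsToDuration(1000001L) == Duration.ofSeconds(1, 1000))
assert(IntervalUtils.periodToMonths(Period.of(2, 3, 0)) == 27)
assert(IntervalUtils.monthsToPeriod(27) == Period.of(2, 3, 0))
assert(IntervalUtils.monthsToPeriod(-13) == Period.of(-1, -1, 0))
// stringToInterval parses the textual form case-insensitively into a CalendarInterval.
val interval = IntervalUtils.stringToInterval(UTF8String.fromString("1 year 2 months"))
assert(interval.months == 14 && interval.days == 0 && interval.microseconds == 0L)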
+ * @return + * Day-time interval string */ def toDayTimeIntervalString( micros: Long, @@ -514,8 +503,9 @@ trait SparkIntervalUtils { rest %= MICROS_PER_MINUTE case DT.SECOND => val leadZero = if (rest < 10 * MICROS_PER_SECOND) "0" else "" - formatBuilder.append(s"$leadZero" + - s"${java.math.BigDecimal.valueOf(rest, 6).stripTrailingZeros.toPlainString}") + formatBuilder.append( + s"$leadZero" + + s"${java.math.BigDecimal.valueOf(rest, 6).stripTrailingZeros.toPlainString}") } if (startField < DT.HOUR && DT.HOUR <= endField) { @@ -565,20 +555,11 @@ trait SparkIntervalUtils { protected val microsStr: UTF8String = unitToUtf8("microsecond") protected val nanosStr: UTF8String = unitToUtf8("nanosecond") - private object ParseState extends Enumeration { type ParseState = Value - val PREFIX, - TRIM_BEFORE_SIGN, - SIGN, - TRIM_BEFORE_VALUE, - VALUE, - VALUE_FRACTIONAL_PART, - TRIM_BEFORE_UNIT, - UNIT_BEGIN, - UNIT_SUFFIX, - UNIT_END = Value + val PREFIX, TRIM_BEFORE_SIGN, SIGN, TRIM_BEFORE_VALUE, VALUE, VALUE_FRACTIONAL_PART, + TRIM_BEFORE_UNIT, UNIT_BEGIN, UNIT_SUFFIX, UNIT_END = Value } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala index 7597cb1d9087d..01ee899085701 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala @@ -62,8 +62,8 @@ trait SparkParserUtils { val secondChar = s.charAt(start + 1) val thirdChar = s.charAt(start + 2) (firstChar == '0' || firstChar == '1') && - (secondChar >= '0' && secondChar <= '7') && - (thirdChar >= '0' && thirdChar <= '7') + (secondChar >= '0' && secondChar <= '7') && + (thirdChar >= '0' && thirdChar <= '7') } val isRawString = { @@ -97,15 +97,18 @@ trait SparkParserUtils { // \u0000 style 16-bit unicode character literals. sb.append(Integer.parseInt(b, i + 1, i + 1 + 4, 16).toChar) i += 1 + 4 - } else if (cAfterBackslash == 'U' && i + 1 + 8 <= length && allCharsAreHex(b, i + 1, 8)) { + } else if (cAfterBackslash == 'U' && i + 1 + 8 <= length && allCharsAreHex( + b, + i + 1, + 8)) { // \U00000000 style 32-bit unicode character literals. // Use Long to treat codePoint as unsigned in the range of 32-bit. 
val codePoint = JLong.parseLong(b, i + 1, i + 1 + 8, 16) if (codePoint < 0x10000) { - sb.append((codePoint & 0xFFFF).toChar) + sb.append((codePoint & 0xffff).toChar) } else { - val highSurrogate = (codePoint - 0x10000) / 0x400 + 0xD800 - val lowSurrogate = (codePoint - 0x10000) % 0x400 + 0xDC00 + val highSurrogate = (codePoint - 0x10000) / 0x400 + 0xd800 + val lowSurrogate = (codePoint - 0x10000) % 0x400 + 0xdc00 sb.append(highSurrogate.toChar) sb.append(lowSurrogate.toChar) } @@ -147,8 +150,13 @@ trait SparkParserUtils { if (text.isEmpty) { CurrentOrigin.set(position(ctx.getStart)) } else { - CurrentOrigin.set(positionAndText(ctx.getStart, ctx.getStop, text.get, - current.objectType, current.objectName)) + CurrentOrigin.set( + positionAndText( + ctx.getStart, + ctx.getStop, + text.get, + current.objectType, + current.objectName)) } try { f diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index edb1ee371b156..0608322be13b3 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -24,9 +24,8 @@ import org.apache.spark.unsafe.array.ByteArrayUtils import org.apache.spark.util.ArrayImplicits._ /** - * Concatenation of sequence of strings to final string with cheap append method - * and one memory allocation for the final string. Can also bound the final size of - * the string. + * Concatenation of sequence of strings to final string with cheap append method and one memory + * allocation for the final string. Can also bound the final size of the string. */ class StringConcat(val maxLength: Int = ByteArrayUtils.MAX_ROUNDED_ARRAY_LENGTH) { protected val strings = new java.util.ArrayList[String] @@ -35,9 +34,9 @@ class StringConcat(val maxLength: Int = ByteArrayUtils.MAX_ROUNDED_ARRAY_LENGTH) def atLimit: Boolean = length >= maxLength /** - * Appends a string and accumulates its length to allocate a string buffer for all - * appended strings once in the toString method. Returns true if the string still - * has room for further appends before it hits its max limit. + * Appends a string and accumulates its length to allocate a string buffer for all appended + * strings once in the toString method. Returns true if the string still has room for further + * appends before it hits its max limit. */ def append(s: String): Unit = { if (s != null) { @@ -56,8 +55,8 @@ class StringConcat(val maxLength: Int = ByteArrayUtils.MAX_ROUNDED_ARRAY_LENGTH) } /** - * The method allocates memory for all appended strings, writes them to the memory and - * returns concatenated string. + * The method allocates memory for all appended strings, writes them to the memory and returns + * concatenated string. */ override def toString: String = { val finalLength = if (atLimit) maxLength else length @@ -68,6 +67,7 @@ class StringConcat(val maxLength: Int = ByteArrayUtils.MAX_ROUNDED_ARRAY_LENGTH) } object SparkStringUtils extends Logging { + /** Whether we have warned about plan string truncation yet. */ private val truncationWarningPrinted = new AtomicBoolean(false) @@ -75,7 +75,8 @@ object SparkStringUtils extends Logging { * Format a sequence with semantics similar to calling .mkString(). Any elements beyond * maxNumToStringFields will be dropped and replaced by a "... N more fields" placeholder. * - * @return the trimmed and formatted string. + * @return + * the trimmed and formatted string. 
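// Illustrative sketch (not part of the patch): the bounded StringConcat described above. The
// final string is materialized once in toString, and atLimit reports when the accumulated
// length has reached the configured bound.
import org.apache.spark.sql.catalyst.util.StringConcat

val concat = new StringConcat(maxLength = 10)
concat.append("Hello, ")
concat.append("World!")
assert(concat.atLimit)               // 13 accumulated characters, bound is 10
assert(concat.toString.length <= 10) // the concatenated result is truncated to the bound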
*/ def truncatedString[T]( seq: Seq[T], @@ -90,8 +91,9 @@ object SparkStringUtils extends Logging { s"behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.") } val numFields = math.max(0, maxFields - 1) - seq.take(numFields).mkString( - start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end) + seq + .take(numFields) + .mkString(start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end) } else { seq.mkString(start, sep, end) } @@ -106,8 +108,8 @@ object SparkStringUtils extends Logging { HexFormat.of().withDelimiter(" ").withUpperCase() /** - * Returns a pretty string of the byte array which prints each byte as a hex digit and add spaces - * between them. For example, [1A C0]. + * Returns a pretty string of the byte array which prints each byte as a hex digit and add + * spaces between them. For example, [1A C0]. */ def getHexString(bytes: Array[Byte]): String = { s"[${SPACE_DELIMITED_UPPERCASE_HEX.formatHex(bytes)}]" @@ -122,8 +124,8 @@ object SparkStringUtils extends Logging { val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("") val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("") - leftPadded.zip(rightPadded).map { - case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r + leftPadded.zip(rightPadded).map { case (l, r) => + (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r } } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index 79d627b493fd8..4fcb84daf187d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -41,14 +41,20 @@ import org.apache.spark.sql.types.{Decimal, TimestampNTZType} import org.apache.spark.unsafe.types.UTF8String sealed trait TimestampFormatter extends Serializable { + /** * Parses a timestamp in a string and converts it to microseconds. * - * @param s - string with timestamp to parse - * @return microseconds since epoch. - * @throws ParseException can be thrown by legacy parser - * @throws DateTimeParseException can be thrown by new parser - * @throws DateTimeException unable to obtain local date or time + * @param s + * \- string with timestamp to parse + * @return + * microseconds since epoch. + * @throws ParseException + * can be thrown by legacy parser + * @throws DateTimeParseException + * can be thrown by new parser + * @throws DateTimeException + * unable to obtain local date or time */ @throws(classOf[ParseException]) @throws(classOf[DateTimeParseException]) @@ -58,11 +64,16 @@ sealed trait TimestampFormatter extends Serializable { /** * Parses a timestamp in a string and converts it to an optional number of microseconds. * - * @param s - string with timestamp to parse - * @return An optional number of microseconds since epoch. The result is None on invalid input. - * @throws ParseException can be thrown by legacy parser - * @throws DateTimeParseException can be thrown by new parser - * @throws DateTimeException unable to obtain local date or time + * @param s + * \- string with timestamp to parse + * @return + * An optional number of microseconds since epoch. The result is None on invalid input. 
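// Illustrative sketch (not from the patch) of the truncation behaviour documented for
// SparkStringUtils.truncatedString in the StringUtils.scala hunk above, assuming the
// five-argument shape (seq, start, sep, end, maxFields) visible there. With maxFields = 3,
// maxFields - 1 = 2 elements are kept and the remaining 2 are reported as dropped:
//   SparkStringUtils.truncatedString(Seq(1, 2, 3, 4), "[", ", ", "]", 3)
//   // expected result: "[1, 2, ... 2 more fields]"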
+ * @throws ParseException + * can be thrown by legacy parser + * @throws DateTimeParseException + * can be thrown by new parser + * @throws DateTimeException + * unable to obtain local date or time */ @throws(classOf[ParseException]) @throws(classOf[DateTimeParseException]) @@ -75,16 +86,24 @@ sealed trait TimestampFormatter extends Serializable { } /** - * Parses a timestamp in a string and converts it to microseconds since Unix Epoch in local time. + * Parses a timestamp in a string and converts it to microseconds since Unix Epoch in local + * time. * - * @param s - string with timestamp to parse - * @param allowTimeZone - indicates strict parsing of timezone - * @return microseconds since epoch. - * @throws ParseException can be thrown by legacy parser - * @throws DateTimeParseException can be thrown by new parser - * @throws DateTimeException unable to obtain local date or time - * @throws IllegalStateException The formatter for timestamp without time zone should always - * implement this method. The exception should never be hit. + * @param s + * \- string with timestamp to parse + * @param allowTimeZone + * \- indicates strict parsing of timezone + * @return + * microseconds since epoch. + * @throws ParseException + * can be thrown by legacy parser + * @throws DateTimeParseException + * can be thrown by new parser + * @throws DateTimeException + * unable to obtain local date or time + * @throws IllegalStateException + * The formatter for timestamp without time zone should always implement this method. The + * exception should never be hit. */ @throws(classOf[ParseException]) @throws(classOf[DateTimeParseException]) @@ -99,14 +118,21 @@ sealed trait TimestampFormatter extends Serializable { * Parses a timestamp in a string and converts it to an optional number of microseconds since * Unix Epoch in local time. * - * @param s - string with timestamp to parse - * @param allowTimeZone - indicates strict parsing of timezone - * @return An optional number of microseconds since epoch. The result is None on invalid input. - * @throws ParseException can be thrown by legacy parser - * @throws DateTimeParseException can be thrown by new parser - * @throws DateTimeException unable to obtain local date or time - * @throws IllegalStateException The formatter for timestamp without time zone should always - * implement this method. The exception should never be hit. + * @param s + * \- string with timestamp to parse + * @param allowTimeZone + * \- indicates strict parsing of timezone + * @return + * An optional number of microseconds since epoch. The result is None on invalid input. + * @throws ParseException + * can be thrown by legacy parser + * @throws DateTimeParseException + * can be thrown by new parser + * @throws DateTimeException + * unable to obtain local date or time + * @throws IllegalStateException + * The formatter for timestamp without time zone should always implement this method. The + * exception should never be hit. */ @throws(classOf[ParseException]) @throws(classOf[DateTimeParseException]) @@ -120,8 +146,8 @@ sealed trait TimestampFormatter extends Serializable { } /** - * Parses a timestamp in a string and converts it to microseconds since Unix Epoch in local time. - * Zone-id and zone-offset components are ignored. + * Parses a timestamp in a string and converts it to microseconds since Unix Epoch in local + * time. Zone-id and zone-offset components are ignored. 
*/ @throws(classOf[ParseException]) @throws(classOf[DateTimeParseException]) @@ -144,9 +170,10 @@ sealed trait TimestampFormatter extends Serializable { /** * Validates the pattern string. - * @param checkLegacy if true and the pattern is invalid, check whether the pattern is valid for - * legacy formatters and show hints for using legacy formatter. - * Otherwise, simply check the pattern string. + * @param checkLegacy + * if true and the pattern is invalid, check whether the pattern is valid for legacy + * formatters and show hints for using legacy formatter. Otherwise, simply check the pattern + * string. */ def validatePatternString(checkLegacy: Boolean): Unit } @@ -157,7 +184,8 @@ class Iso8601TimestampFormatter( locale: Locale, legacyFormat: LegacyDateFormat = LENIENT_SIMPLE_DATE_FORMAT, isParsing: Boolean) - extends TimestampFormatter with DateTimeFormatterHelper { + extends TimestampFormatter + with DateTimeFormatterHelper { @transient protected lazy val formatter: DateTimeFormatter = getOrCreateFormatter(pattern, locale, isParsing) @@ -166,8 +194,8 @@ class Iso8601TimestampFormatter( private lazy val zonedFormatter: DateTimeFormatter = formatter.withZone(zoneId) @transient - protected lazy val legacyFormatter = TimestampFormatter.getLegacyFormatter( - pattern, zoneId, locale, legacyFormat) + protected lazy val legacyFormatter = + TimestampFormatter.getLegacyFormatter(pattern, zoneId, locale, legacyFormat) override def parseOptional(s: String): Option[Long] = { try { @@ -235,8 +263,8 @@ class Iso8601TimestampFormatter( override def format(instant: Instant): String = { try { zonedFormatter.format(instant) - } catch checkFormattedDiff(toJavaTimestamp(instantToMicros(instant)), - (t: Timestamp) => format(t)) + } catch + checkFormattedDiff(toJavaTimestamp(instantToMicros(instant)), (t: Timestamp) => format(t)) } override def format(us: Long): String = { @@ -256,8 +284,8 @@ class Iso8601TimestampFormatter( if (checkLegacy) { try { formatter - } catch checkLegacyFormatter(pattern, - legacyFormatter.validatePatternString(checkLegacy = true)) + } catch + checkLegacyFormatter(pattern, legacyFormatter.validatePatternString(checkLegacy = true)) () } else { try { @@ -268,22 +296,30 @@ class Iso8601TimestampFormatter( } /** - * The formatter for timestamps which doesn't require users to specify a pattern. While formatting, - * it uses the default pattern [[TimestampFormatter.defaultPattern()]]. In parsing, it follows - * the CAST logic in conversion of strings to Catalyst's TimestampType. + * The formatter for timestamps which doesn't require users to specify a pattern. While + * formatting, it uses the default pattern [[TimestampFormatter.defaultPattern()]]. In parsing, it + * follows the CAST logic in conversion of strings to Catalyst's TimestampType. * - * @param zoneId The time zone ID in which timestamps should be formatted or parsed. - * @param locale The locale overrides the system locale and is used in formatting. - * @param legacyFormat Defines the formatter used for legacy timestamps. - * @param isParsing Whether the formatter is used for parsing (`true`) or for formatting (`false`). + * @param zoneId + * The time zone ID in which timestamps should be formatted or parsed. + * @param locale + * The locale overrides the system locale and is used in formatting. + * @param legacyFormat + * Defines the formatter used for legacy timestamps. + * @param isParsing + * Whether the formatter is used for parsing (`true`) or for formatting (`false`). 
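// Illustrative sketch (not from the patch) of the parse/format contract documented above,
// using the companion apply(format, zoneId, isParsing) shown later in this file's hunks;
// the pattern, zone and input value are arbitrary examples:
//   val fmt = TimestampFormatter("yyyy-MM-dd HH:mm:ss", java.time.ZoneId.of("UTC"), isParsing = true)
//   val micros: Long = fmt.parse("2019-03-05 15:00:01") // microseconds since the epoch
//   val text: String = fmt.format(micros)               // should round-trip to "2019-03-05 15:00:01"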
*/ class DefaultTimestampFormatter( zoneId: ZoneId, locale: Locale, legacyFormat: LegacyDateFormat = LENIENT_SIMPLE_DATE_FORMAT, isParsing: Boolean) - extends Iso8601TimestampFormatter( - TimestampFormatter.defaultPattern(), zoneId, locale, legacyFormat, isParsing) { + extends Iso8601TimestampFormatter( + TimestampFormatter.defaultPattern(), + zoneId, + locale, + legacyFormat, + isParsing) { override def parse(s: String): Long = { try { @@ -299,7 +335,9 @@ class DefaultTimestampFormatter( val utf8Value = UTF8String.fromString(s) SparkDateTimeUtils.stringToTimestampWithoutTimeZone(utf8Value, allowTimeZone).getOrElse { throw ExecutionErrors.cannotParseStringAsDataTypeError( - TimestampFormatter.defaultPattern(), s, TimestampNTZType) + TimestampFormatter.defaultPattern(), + s, + TimestampNTZType) } } catch checkParsedDiff(s, legacyFormatter.parse) } @@ -311,20 +349,21 @@ class DefaultTimestampFormatter( } /** - * The formatter parses/formats timestamps according to the pattern `yyyy-MM-dd HH:mm:ss.[..fff..]` - * where `[..fff..]` is a fraction of second up to microsecond resolution. The formatter does not - * output trailing zeros in the fraction. For example, the timestamp `2019-03-05 15:00:01.123400` is - * formatted as the string `2019-03-05 15:00:01.1234`. + * The formatter parses/formats timestamps according to the pattern `yyyy-MM-dd + * HH:mm:ss.[..fff..]` where `[..fff..]` is a fraction of second up to microsecond resolution. The + * formatter does not output trailing zeros in the fraction. For example, the timestamp + * `2019-03-05 15:00:01.123400` is formatted as the string `2019-03-05 15:00:01.1234`. * - * @param zoneId the time zone identifier in which the formatter parses or format timestamps + * @param zoneId + * the time zone identifier in which the formatter parses or format timestamps */ class FractionTimestampFormatter(zoneId: ZoneId) - extends Iso8601TimestampFormatter( - TimestampFormatter.defaultPattern(), - zoneId, - TimestampFormatter.defaultLocale, - LegacyDateFormats.FAST_DATE_FORMAT, - isParsing = false) { + extends Iso8601TimestampFormatter( + TimestampFormatter.defaultPattern(), + zoneId, + TimestampFormatter.defaultLocale, + LegacyDateFormats.FAST_DATE_FORMAT, + isParsing = false) { @transient override protected lazy val formatter = DateTimeFormatterHelper.fractionFormatter @@ -366,16 +405,14 @@ class FractionTimestampFormatter(zoneId: ZoneId) } /** - * The custom sub-class of `GregorianCalendar` is needed to get access to - * protected `fields` immediately after parsing. We cannot use - * the `get()` method because it performs normalization of the fraction - * part. Accordingly, the `MILLISECOND` field doesn't contain original value. + * The custom sub-class of `GregorianCalendar` is needed to get access to protected `fields` + * immediately after parsing. We cannot use the `get()` method because it performs normalization + * of the fraction part. Accordingly, the `MILLISECOND` field doesn't contain original value. * - * Also this class allows to set raw value to the `MILLISECOND` field - * directly before formatting. + * Also this class allows to set raw value to the `MILLISECOND` field directly before formatting. */ class MicrosCalendar(tz: TimeZone, digitsInFraction: Int) - extends GregorianCalendar(tz, Locale.US) { + extends GregorianCalendar(tz, Locale.US) { // Converts parsed `MILLISECOND` field to seconds fraction in microsecond precision. 
// For example if the fraction pattern is `SSSS` then `digitsInFraction` = 4, and // if the `MILLISECOND` field was parsed to `1234`. @@ -397,16 +434,13 @@ class MicrosCalendar(tz: TimeZone, digitsInFraction: Int) } } -class LegacyFastTimestampFormatter( - pattern: String, - zoneId: ZoneId, - locale: Locale) extends TimestampFormatter { +class LegacyFastTimestampFormatter(pattern: String, zoneId: ZoneId, locale: Locale) + extends TimestampFormatter { @transient private lazy val fastDateFormat = FastDateFormat.getInstance(pattern, TimeZone.getTimeZone(zoneId), locale) - @transient private lazy val cal = new MicrosCalendar( - fastDateFormat.getTimeZone, - fastDateFormat.getPattern.count(_ == 'S')) + @transient private lazy val cal = + new MicrosCalendar(fastDateFormat.getTimeZone, fastDateFormat.getPattern.count(_ == 'S')) override def parse(s: String): Long = { cal.clear() // Clear the calendar because it can be re-used many times @@ -464,7 +498,8 @@ class LegacySimpleTimestampFormatter( pattern: String, zoneId: ZoneId, locale: Locale, - lenient: Boolean = true) extends TimestampFormatter { + lenient: Boolean = true) + extends TimestampFormatter { @transient private lazy val sdf = { val formatter = new SimpleDateFormat(pattern, locale) formatter.setTimeZone(TimeZone.getTimeZone(zoneId)) @@ -586,14 +621,10 @@ object TimestampFormatter { legacyFormat: LegacyDateFormat, isParsing: Boolean, forTimestampNTZ: Boolean): TimestampFormatter = { - getFormatter(Some(format), zoneId, defaultLocale, legacyFormat, isParsing, - forTimestampNTZ) + getFormatter(Some(format), zoneId, defaultLocale, legacyFormat, isParsing, forTimestampNTZ) } - def apply( - format: String, - zoneId: ZoneId, - isParsing: Boolean): TimestampFormatter = { + def apply(format: String, zoneId: ZoneId, isParsing: Boolean): TimestampFormatter = { getFormatter(Some(format), zoneId, isParsing = isParsing) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/UDTUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/UDTUtils.scala index a98aa26d02ef7..73ab43f04a5a0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/UDTUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/UDTUtils.scala @@ -28,11 +28,12 @@ import org.apache.spark.util.SparkClassUtils * with catalyst because they (amongst others) require access to Spark SQLs internal data * representation. * - * This interface and its companion object provide an escape hatch for working with UDTs from within - * the api project (e.g. Row.toJSON). The companion will try to bind to an implementation of the - * interface in catalyst, if none is found it will bind to [[DefaultUDTUtils]]. + * This interface and its companion object provide an escape hatch for working with UDTs from + * within the api project (e.g. Row.toJSON). The companion will try to bind to an implementation + * of the interface in catalyst, if none is found it will bind to [[DefaultUDTUtils]]. */ private[sql] trait UDTUtils { + /** * Convert the UDT instance to something that is compatible with [[org.apache.spark.sql.Row]]. * The returned value must conform to the schema of the UDT. 
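// Illustrative sketch (not from the patch) of the escape hatch described above; `point` and
// `pointUdt` are hypothetical placeholders for a UDT value and its UserDefinedType instance:
//   val asRowValue = UDTUtils.toRow(point, pointUdt.asInstanceOf[UserDefinedType[Any]])
//   // resolves to the catalyst implementation when present on the classpath, else DefaultUDTUtils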
@@ -41,13 +42,14 @@ private[sql] trait UDTUtils { } private[sql] object UDTUtils extends UDTUtils { - private val delegate = try { - val cls = SparkClassUtils.classForName("org.apache.spark.sql.catalyst.util.UDTUtilsImpl") - cls.getConstructor().newInstance().asInstanceOf[UDTUtils] - } catch { - case NonFatal(_) => - DefaultUDTUtils - } + private val delegate = + try { + val cls = SparkClassUtils.classForName("org.apache.spark.sql.catalyst.util.UDTUtilsImpl") + cls.getConstructor().newInstance().asInstanceOf[UDTUtils] + } catch { + case NonFatal(_) => + DefaultUDTUtils + } override def toRow(value: Any, udt: UserDefinedType[Any]): Any = delegate.toRow(value, udt) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/CompilationErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/CompilationErrors.scala index 6034c41906313..3e63b8281f739 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/CompilationErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/CompilationErrors.scala @@ -23,9 +23,7 @@ private[sql] trait CompilationErrors extends DataTypeErrorsBase { def ambiguousColumnOrFieldError(name: Seq[String], numMatches: Int): AnalysisException = { new AnalysisException( errorClass = "AMBIGUOUS_COLUMN_OR_FIELD", - messageParameters = Map( - "name" -> toSQLId(name), - "n" -> numMatches.toString)) + messageParameters = Map("name" -> toSQLId(name), "n" -> numMatches.toString)) } def columnNotFoundError(colName: String): AnalysisException = { @@ -51,9 +49,7 @@ private[sql] trait CompilationErrors extends DataTypeErrorsBase { } def usingUntypedScalaUDFError(): Throwable = { - new AnalysisException( - errorClass = "UNTYPED_SCALA_UDF", - messageParameters = Map.empty) + new AnalysisException(errorClass = "UNTYPED_SCALA_UDF", messageParameters = Map.empty) } def invalidBoundaryStartError(start: Long): Throwable = { @@ -81,14 +77,11 @@ private[sql] trait CompilationErrors extends DataTypeErrorsBase { def invalidSaveModeError(saveMode: String): Throwable = { new AnalysisException( errorClass = "INVALID_SAVE_MODE", - messageParameters = Map("mode" -> toDSOption(saveMode)) - ) + messageParameters = Map("mode" -> toDSOption(saveMode))) } def sortByWithoutBucketingError(): Throwable = { - new AnalysisException( - errorClass = "SORT_BY_WITHOUT_BUCKETING", - messageParameters = Map.empty) + new AnalysisException(errorClass = "SORT_BY_WITHOUT_BUCKETING", messageParameters = Map.empty) } def bucketByUnsupportedByOperationError(operation: String): Throwable = { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala index 32b4198dc1a63..388a98569258b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.types.{DataType, Decimal, StringType} import org.apache.spark.unsafe.types.UTF8String /** - * Object for grouping error messages from (most) exceptions thrown during query execution. - * This does not include exceptions thrown during the eager execution of commands, which are - * grouped into [[CompilationErrors]]. + * Object for grouping error messages from (most) exceptions thrown during query execution. This + * does not include exceptions thrown during the eager execution of commands, which are grouped + * into [[CompilationErrors]]. 
*/ private[sql] object DataTypeErrors extends DataTypeErrorsBase { def unsupportedOperationExceptionError(): SparkUnsupportedOperationException = { @@ -35,13 +35,12 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { } def decimalPrecisionExceedsMaxPrecisionError( - precision: Int, maxPrecision: Int): SparkArithmeticException = { + precision: Int, + maxPrecision: Int): SparkArithmeticException = { new SparkArithmeticException( errorClass = "DECIMAL_PRECISION_EXCEEDS_MAX_PRECISION", - messageParameters = Map( - "precision" -> precision.toString, - "maxPrecision" -> maxPrecision.toString - ), + messageParameters = + Map("precision" -> precision.toString, "maxPrecision" -> maxPrecision.toString), context = Array.empty, summary = "") } @@ -53,8 +52,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { def outOfDecimalTypeRangeError(str: UTF8String): SparkArithmeticException = { new SparkArithmeticException( errorClass = "NUMERIC_OUT_OF_SUPPORTED_RANGE", - messageParameters = Map( - "value" -> str.toString), + messageParameters = Map("value" -> str.toString), context = Array.empty, summary = "") } @@ -68,25 +66,20 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { def nullLiteralsCannotBeCastedError(name: String): SparkUnsupportedOperationException = { new SparkUnsupportedOperationException( errorClass = "_LEGACY_ERROR_TEMP_2226", - messageParameters = Map( - "name" -> name)) + messageParameters = Map("name" -> name)) } def notUserDefinedTypeError(name: String, userClass: String): Throwable = { new SparkException( errorClass = "_LEGACY_ERROR_TEMP_2227", - messageParameters = Map( - "name" -> name, - "userClass" -> userClass), + messageParameters = Map("name" -> name, "userClass" -> userClass), cause = null) } def cannotLoadUserDefinedTypeError(name: String, userClass: String): Throwable = { new SparkException( errorClass = "_LEGACY_ERROR_TEMP_2228", - messageParameters = Map( - "name" -> name, - "userClass" -> userClass), + messageParameters = Map("name" -> name, "userClass" -> userClass), cause = null) } @@ -99,50 +92,42 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { def schemaFailToParseError(schema: String, e: Throwable): Throwable = { new AnalysisException( errorClass = "INVALID_SCHEMA.PARSE_ERROR", - messageParameters = Map( - "inputSchema" -> toSQLSchema(schema), - "reason" -> e.getMessage - ), + messageParameters = Map("inputSchema" -> toSQLSchema(schema), "reason" -> e.getMessage), cause = Some(e)) } def invalidDayTimeIntervalType(startFieldName: String, endFieldName: String): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1224", - messageParameters = Map( - "startFieldName" -> startFieldName, - "endFieldName" -> endFieldName)) + messageParameters = Map("startFieldName" -> startFieldName, "endFieldName" -> endFieldName)) } def invalidDayTimeField(field: Byte, supportedIds: Seq[String]): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1223", - messageParameters = Map( - "field" -> field.toString, - "supportedIds" -> supportedIds.mkString(", "))) + messageParameters = + Map("field" -> field.toString, "supportedIds" -> supportedIds.mkString(", "))) } def invalidYearMonthField(field: Byte, supportedIds: Seq[String]): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1225", - messageParameters = Map( - "field" -> field.toString, - "supportedIds" -> supportedIds.mkString(", "))) + messageParameters = + Map("field" -> field.toString, "supportedIds" -> 
supportedIds.mkString(", "))) } def decimalCannotGreaterThanPrecisionError(scale: Int, precision: Int): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1228", - messageParameters = Map( - "scale" -> scale.toString, - "precision" -> precision.toString)) + messageParameters = Map("scale" -> scale.toString, "precision" -> precision.toString)) } def negativeScaleNotAllowedError(scale: Int): Throwable = { val sqlConf = QuotingUtils.toSQLConf("spark.sql.legacy.allowNegativeScaleOfDecimal") - SparkException.internalError(s"Negative scale is not allowed: ${scale.toString}." + - s" Set the config ${sqlConf}" + - " to \"true\" to allow it.") + SparkException.internalError( + s"Negative scale is not allowed: ${scale.toString}." + + s" Set the config ${sqlConf}" + + " to \"true\" to allow it.") } def attributeNameSyntaxError(name: String): Throwable = { @@ -154,19 +139,17 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { def cannotMergeIncompatibleDataTypesError(left: DataType, right: DataType): Throwable = { new SparkException( errorClass = "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE", - messageParameters = Map( - "left" -> toSQLType(left), - "right" -> toSQLType(right)), + messageParameters = Map("left" -> toSQLType(left), "right" -> toSQLType(right)), cause = null) } def cannotMergeDecimalTypesWithIncompatibleScaleError( - leftScale: Int, rightScale: Int): Throwable = { + leftScale: Int, + rightScale: Int): Throwable = { new SparkException( errorClass = "_LEGACY_ERROR_TEMP_2124", - messageParameters = Map( - "leftScale" -> leftScale.toString(), - "rightScale" -> rightScale.toString()), + messageParameters = + Map("leftScale" -> leftScale.toString(), "rightScale" -> rightScale.toString()), cause = null) } @@ -179,9 +162,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { def invalidFieldName(fieldName: Seq[String], path: Seq[String], context: Origin): Throwable = { new AnalysisException( errorClass = "INVALID_FIELD_NAME", - messageParameters = Map( - "fieldName" -> toSQLId(fieldName), - "path" -> toSQLId(path)), + messageParameters = Map("fieldName" -> toSQLId(fieldName), "path" -> toSQLId(path)), origin = context) } @@ -227,30 +208,26 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { messageParameters = Map( "expression" -> convertedValueStr, "sourceType" -> toSQLType(StringType), - "targetType" -> toSQLType(to), - "ansiConfig" -> toSQLConf("spark.sql.ansi.enabled")), + "targetType" -> toSQLType(to)), context = getQueryContext(context), summary = getSummary(context)) } def ambiguousColumnOrFieldError( - name: Seq[String], numMatches: Int, context: Origin): Throwable = { + name: Seq[String], + numMatches: Int, + context: Origin): Throwable = { new AnalysisException( errorClass = "AMBIGUOUS_COLUMN_OR_FIELD", - messageParameters = Map( - "name" -> toSQLId(name), - "n" -> numMatches.toString), + messageParameters = Map("name" -> toSQLId(name), "n" -> numMatches.toString), origin = context) } def castingCauseOverflowError(t: String, from: DataType, to: DataType): ArithmeticException = { new SparkArithmeticException( errorClass = "CAST_OVERFLOW", - messageParameters = Map( - "value" -> t, - "sourceType" -> toSQLType(from), - "targetType" -> toSQLType(to), - "ansiConfig" -> toSQLConf("spark.sql.ansi.enabled")), + messageParameters = + Map("value" -> t, "sourceType" -> toSQLType(from), "targetType" -> toSQLType(to)), context = Array.empty, summary = "") } @@ -267,15 +244,13 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { 
messageParameters = Map( "methodName" -> "fieldIndex", "className" -> "Row", - "fieldName" -> toSQLId(fieldName)) - ) + "fieldName" -> toSQLId(fieldName))) } def valueIsNullError(index: Int): Throwable = { new SparkRuntimeException( errorClass = "ROW_VALUE_IS_NULL", - messageParameters = Map( - "index" -> index.toString), + messageParameters = Map("index" -> index.toString), cause = null) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala index 7910c386fcf14..698a7b096e1a5 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala @@ -60,7 +60,8 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { } def failToRecognizePatternAfterUpgradeError( - pattern: String, e: Throwable): SparkUpgradeException = { + pattern: String, + e: Throwable): SparkUpgradeException = { new SparkUpgradeException( errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.DATETIME_PATTERN_RECOGNITION", messageParameters = Map( @@ -73,9 +74,8 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def failToRecognizePatternError(pattern: String, e: Throwable): SparkRuntimeException = { new SparkRuntimeException( errorClass = "_LEGACY_ERROR_TEMP_2130", - messageParameters = Map( - "pattern" -> toSQLValue(pattern), - "docroot" -> SparkBuildInfo.spark_doc_root), + messageParameters = + Map("pattern" -> toSQLValue(pattern), "docroot" -> SparkBuildInfo.spark_doc_root), cause = e) } @@ -93,9 +93,9 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { } def invalidInputInCastToDatetimeError( - value: Double, - to: DataType, - context: QueryContext): SparkDateTimeException = { + value: Double, + to: DataType, + context: QueryContext): SparkDateTimeException = { invalidInputInCastToDatetimeErrorInternal(toSQLValue(value), DoubleType, to, context) } @@ -109,8 +109,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { messageParameters = Map( "expression" -> sqlValue, "sourceType" -> toSQLType(from), - "targetType" -> toSQLType(to), - "ansiConfig" -> toSQLConf(SqlApiConf.ANSI_ENABLED_KEY)), + "targetType" -> toSQLType(to)), context = getQueryContext(context), summary = getSummary(context)) } @@ -132,8 +131,10 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { summary = getSummary(context)) } - def cannotParseStringAsDataTypeError(pattern: String, value: String, dataType: DataType) - : Throwable = { + def cannotParseStringAsDataTypeError( + pattern: String, + value: String, + dataType: DataType): Throwable = { SparkException.internalError( s"Cannot parse field value ${toSQLValue(value)} for pattern ${toSQLValue(pattern)} " + s"as the target spark data type ${toSQLType(dataType)}.") @@ -161,17 +162,14 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def userDefinedTypeNotAnnotatedAndRegisteredError(udt: UserDefinedType[_]): Throwable = { new SparkException( errorClass = "_LEGACY_ERROR_TEMP_2155", - messageParameters = Map( - "userClass" -> udt.userClass.getName), + messageParameters = Map("userClass" -> udt.userClass.getName), cause = null) } def cannotFindEncoderForTypeError(typeName: String): SparkUnsupportedOperationException = { new SparkUnsupportedOperationException( errorClass = "ENCODER_NOT_FOUND", - messageParameters = Map( - "typeName" -> typeName, - "docroot" -> SparkBuildInfo.spark_doc_root)) + messageParameters = 
Map("typeName" -> typeName, "docroot" -> SparkBuildInfo.spark_doc_root)) } def cannotHaveCircularReferencesInBeanClassError( @@ -184,8 +182,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def cannotFindConstructorForTypeError(tpe: String): SparkUnsupportedOperationException = { new SparkUnsupportedOperationException( errorClass = "_LEGACY_ERROR_TEMP_2144", - messageParameters = Map( - "tpe" -> tpe)) + messageParameters = Map("tpe" -> tpe)) } def cannotHaveCircularReferencesInClassError(t: String): SparkUnsupportedOperationException = { @@ -195,12 +192,12 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { } def cannotUseInvalidJavaIdentifierAsFieldNameError( - fieldName: String, walkedTypePath: WalkedTypePath): SparkUnsupportedOperationException = { + fieldName: String, + walkedTypePath: WalkedTypePath): SparkUnsupportedOperationException = { new SparkUnsupportedOperationException( errorClass = "_LEGACY_ERROR_TEMP_2140", - messageParameters = Map( - "fieldName" -> fieldName, - "walkedTypePath" -> walkedTypePath.toString)) + messageParameters = + Map("fieldName" -> fieldName, "walkedTypePath" -> walkedTypePath.toString)) } def primaryConstructorNotFoundError(cls: Class[_]): SparkRuntimeException = { @@ -213,8 +210,33 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def cannotGetOuterPointerForInnerClassError(innerCls: Class[_]): SparkRuntimeException = { new SparkRuntimeException( errorClass = "_LEGACY_ERROR_TEMP_2154", + messageParameters = Map("innerCls" -> innerCls.getName)) + } + + def cannotUseKryoSerialization(): SparkRuntimeException = { + new SparkRuntimeException(errorClass = "CANNOT_USE_KRYO", messageParameters = Map.empty) + } + + def notPublicClassError(name: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2229", + messageParameters = Map("name" -> name)) + } + + def primitiveTypesNotSupportedError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException(errorClass = "_LEGACY_ERROR_TEMP_2230") + } + + def elementsOfTupleExceedLimitError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_2150") + } + + def invalidAgnosticEncoderError(encoder: AnyRef): Throwable = { + new SparkRuntimeException( + errorClass = "INVALID_AGNOSTIC_ENCODER", messageParameters = Map( - "innerCls" -> innerCls.getName)) + "encoderType" -> encoder.getClass.getName, + "docroot" -> SparkBuildInfo.spark_doc_root)) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index e7ae9f2bfb7bb..b19607a28f06c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.trees.Origin /** - * Object for grouping all error messages of the query parsing. - * Currently it includes all ParseException. + * Object for grouping all error messages of the query parsing. Currently it includes all + * ParseException. 
*/ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { @@ -37,9 +37,7 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { } def parserStackOverflow(parserRuleContext: ParserRuleContext): Throwable = { - throw new ParseException( - errorClass = "FAILED_TO_PARSE_TOO_COMPLEX", - ctx = parserRuleContext) + throw new ParseException(errorClass = "FAILED_TO_PARSE_TOO_COMPLEX", ctx = parserRuleContext) } def insertOverwriteDirectoryUnsupportedError(): Throwable = { @@ -160,7 +158,9 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { } def incompatibleJoinTypesError( - joinType1: String, joinType2: String, ctx: ParserRuleContext): Throwable = { + joinType1: String, + joinType2: String, + ctx: ParserRuleContext): Throwable = { new ParseException( errorClass = "INCOMPATIBLE_JOIN_TYPES", messageParameters = Map( @@ -209,13 +209,12 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { } def cannotParseValueTypeError( - valueType: String, value: String, ctx: TypeConstructorContext): Throwable = { + valueType: String, + value: String, + ctx: TypeConstructorContext): Throwable = { new ParseException( errorClass = "INVALID_TYPED_LITERAL", - messageParameters = Map( - "valueType" -> toSQLType(valueType), - "value" -> toSQLValue(value) - ), + messageParameters = Map("valueType" -> toSQLType(valueType), "value" -> toSQLValue(value)), ctx) } @@ -231,8 +230,12 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { ctx) } - def invalidNumericLiteralRangeError(rawStrippedQualifier: String, minValue: BigDecimal, - maxValue: BigDecimal, typeName: String, ctx: NumberContext): Throwable = { + def invalidNumericLiteralRangeError( + rawStrippedQualifier: String, + minValue: BigDecimal, + maxValue: BigDecimal, + typeName: String, + ctx: NumberContext): Throwable = { new ParseException( errorClass = "INVALID_NUMERIC_LITERAL_RANGE", messageParameters = Map( @@ -259,7 +262,9 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { } def fromToIntervalUnsupportedError( - from: String, to: String, ctx: ParserRuleContext): Throwable = { + from: String, + to: String, + ctx: ParserRuleContext): Throwable = { new ParseException( errorClass = "_LEGACY_ERROR_TEMP_0028", messageParameters = Map("from" -> from, "to" -> to), @@ -288,7 +293,8 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { } def nestedTypeMissingElementTypeError( - dataType: String, ctx: PrimitiveDataTypeContext): Throwable = { + dataType: String, + ctx: PrimitiveDataTypeContext): Throwable = { dataType.toUpperCase(Locale.ROOT) match { case "ARRAY" => new ParseException( @@ -309,23 +315,25 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { } def partitionTransformNotExpectedError( - name: String, expr: String, ctx: ApplyTransformContext): Throwable = { + name: String, + expr: String, + ctx: ApplyTransformContext): Throwable = { new ParseException( errorClass = "INVALID_SQL_SYNTAX.INVALID_COLUMN_REFERENCE", - messageParameters = Map( - "transform" -> toSQLId(name), - "expr" -> expr), + messageParameters = Map("transform" -> toSQLId(name), "expr" -> expr), ctx) } def wrongNumberArgumentsForTransformError( - name: String, actualNum: Int, ctx: ApplyTransformContext): Throwable = { + name: String, + actualNum: Int, + ctx: ApplyTransformContext): Throwable = { new ParseException( errorClass = "INVALID_SQL_SYNTAX.TRANSFORM_WRONG_NUM_ARGS", messageParameters = Map( "transform" -> toSQLId(name), - "expectedNum" -> "1", - "actualNum" -> 
actualNum.toString), + "expectedNum" -> "1", + "actualNum" -> actualNum.toString), ctx) } @@ -337,7 +345,9 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { } def cannotCleanReservedNamespacePropertyError( - property: String, ctx: ParserRuleContext, msg: String): Throwable = { + property: String, + ctx: ParserRuleContext, + msg: String): Throwable = { new ParseException( errorClass = "UNSUPPORTED_FEATURE.SET_NAMESPACE_PROPERTY", messageParameters = Map("property" -> property, "msg" -> msg), @@ -348,12 +358,13 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { new ParseException( errorClass = "UNSUPPORTED_FEATURE.SET_PROPERTIES_AND_DBPROPERTIES", messageParameters = Map.empty, - ctx - ) + ctx) } def cannotCleanReservedTablePropertyError( - property: String, ctx: ParserRuleContext, msg: String): Throwable = { + property: String, + ctx: ParserRuleContext, + msg: String): Throwable = { new ParseException( errorClass = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", messageParameters = Map("property" -> property, "msg" -> msg), @@ -361,12 +372,12 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { } def duplicatedTablePathsFoundError( - pathOne: String, pathTwo: String, ctx: ParserRuleContext): Throwable = { + pathOne: String, + pathTwo: String, + ctx: ParserRuleContext): Throwable = { new ParseException( errorClass = "_LEGACY_ERROR_TEMP_0032", - messageParameters = Map( - "pathOne" -> pathOne, - "pathTwo" -> pathTwo), + messageParameters = Map("pathOne" -> pathOne, "pathTwo" -> pathTwo), ctx) } @@ -374,15 +385,17 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0033", ctx) } - def operationInHiveStyleCommandUnsupportedError(operation: String, - command: String, ctx: StatementContext, msgOpt: Option[String] = None): Throwable = { + def operationInHiveStyleCommandUnsupportedError( + operation: String, + command: String, + ctx: StatementContext, + msgOpt: Option[String] = None): Throwable = { new ParseException( errorClass = "_LEGACY_ERROR_TEMP_0034", messageParameters = Map( "operation" -> operation, "command" -> command, - "msg" -> msgOpt.map(m => s", $m").getOrElse("") - ), + "msg" -> msgOpt.map(m => s", $m").getOrElse("")), ctx) } @@ -415,7 +428,8 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { } def addCatalogInCacheTableAsSelectNotAllowedError( - quoted: String, ctx: CacheTableContext): Throwable = { + quoted: String, + ctx: CacheTableContext): Throwable = { new ParseException( errorClass = "_LEGACY_ERROR_TEMP_0037", messageParameters = Map("quoted" -> quoted), @@ -479,22 +493,22 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { } def invalidPropertyKeyForSetQuotedConfigurationError( - keyCandidate: String, valueStr: String, ctx: ParserRuleContext): Throwable = { + keyCandidate: String, + valueStr: String, + ctx: ParserRuleContext): Throwable = { new ParseException( errorClass = "INVALID_PROPERTY_KEY", - messageParameters = Map( - "key" -> toSQLConf(keyCandidate), - "value" -> toSQLConf(valueStr)), + messageParameters = Map("key" -> toSQLConf(keyCandidate), "value" -> toSQLConf(valueStr)), ctx) } def invalidPropertyValueForSetQuotedConfigurationError( - valueCandidate: String, keyStr: String, ctx: ParserRuleContext): Throwable = { + valueCandidate: String, + keyStr: String, + ctx: ParserRuleContext): Throwable = { new ParseException( errorClass = "INVALID_PROPERTY_VALUE", - messageParameters = Map( - "value" -> 
toSQLConf(valueCandidate), - "key" -> toSQLConf(keyStr)), + messageParameters = Map("value" -> toSQLConf(valueCandidate), "key" -> toSQLConf(keyStr)), ctx) } @@ -542,12 +556,32 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { ctx) } + def identityColumnUnsupportedDataType( + ctx: IdentityColumnContext, + dataType: String): Throwable = { + new ParseException("IDENTITY_COLUMNS_UNSUPPORTED_DATA_TYPE", Map("dataType" -> dataType), ctx) + } + + def identityColumnIllegalStep(ctx: IdentityColSpecContext): Throwable = { + new ParseException("IDENTITY_COLUMNS_ILLEGAL_STEP", Map.empty, ctx) + } + + def identityColumnDuplicatedSequenceGeneratorOption( + ctx: IdentityColSpecContext, + sequenceGeneratorOption: String): Throwable = { + new ParseException( + "IDENTITY_COLUMNS_DUPLICATED_SEQUENCE_GENERATOR_OPTION", + Map("sequenceGeneratorOption" -> sequenceGeneratorOption), + ctx) + } + def createViewWithBothIfNotExistsAndReplaceError(ctx: CreateViewContext): Throwable = { new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0052", ctx) } def temporaryViewWithSchemaBindingMode(ctx: StatementContext): Throwable = { - new ParseException(errorClass = "UNSUPPORTED_FEATURE.TEMPORARY_VIEW_WITH_SCHEMA_BINDING_MODE", + new ParseException( + errorClass = "UNSUPPORTED_FEATURE.TEMPORARY_VIEW_WITH_SCHEMA_BINDING_MODE", messageParameters = Map.empty, ctx) } @@ -581,9 +615,7 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { } def defineTempFuncWithIfNotExistsError(ctx: ParserRuleContext): Throwable = { - new ParseException( - errorClass = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_IF_NOT_EXISTS", - ctx) + new ParseException(errorClass = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_IF_NOT_EXISTS", ctx) } def unsupportedFunctionNameError(funcName: Seq[String], ctx: ParserRuleContext): Throwable = { @@ -632,9 +664,8 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { def invalidNameForDropTempFunc(name: Seq[String], ctx: ParserRuleContext): Throwable = { new ParseException( errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", - messageParameters = Map( - "statement" -> toSQLStmt("DROP TEMPORARY FUNCTION"), - "funcName" -> toSQLId(name)), + messageParameters = + Map("statement" -> toSQLStmt("DROP TEMPORARY FUNCTION"), "funcName" -> toSQLId(name)), ctx) } @@ -650,9 +681,7 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { new ParseException(errorClass = "REF_DEFAULT_VALUE_IS_NOT_ALLOWED_IN_PARTITION", ctx) } - def duplicateArgumentNamesError( - arguments: Seq[String], - ctx: ParserRuleContext): Throwable = { + def duplicateArgumentNamesError(arguments: Seq[String], ctx: ParserRuleContext): Throwable = { new ParseException( errorClass = "EXEC_IMMEDIATE_DUPLICATE_ARGUMENT_ALIASES", messageParameters = Map("aliases" -> arguments.map(toSQLId).mkString(", ")), @@ -679,12 +708,9 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { } new ParseException( errorClass = errorClass, - messageParameters = alterTypeMap ++ Map( - "columnName" -> columnName, - "optionName" -> optionName - ), - ctx - ) + messageParameters = + alterTypeMap ++ Map("columnName" -> columnName, "optionName" -> optionName), + ctx) } def invalidDatetimeUnitError( @@ -697,19 +723,17 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { "functionName" -> toSQLId(functionName), "parameter" -> toSQLId("unit"), "invalidValue" -> invalidValue), - ctx - ) + ctx) } def invalidTableFunctionIdentifierArgumentMissingParentheses( - ctx: ParserRuleContext, 
argumentName: String): Throwable = { + ctx: ParserRuleContext, + argumentName: String): Throwable = { new ParseException( errorClass = "INVALID_SQL_SYNTAX.INVALID_TABLE_FUNCTION_IDENTIFIER_ARGUMENT_MISSING_PARENTHESES", - messageParameters = Map( - "argumentName" -> toSQLId(argumentName)), - ctx - ) + messageParameters = Map("argumentName" -> toSQLId(argumentName)), + ctx) } def clusterByWithPartitionedBy(ctx: ParserRuleContext): Throwable = { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/api/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index 146012b4266dd..7d8b33aa5e228 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -88,8 +88,8 @@ object ProcessingTimeTrigger { } /** - * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. + * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at the + * specified interval. */ case class ContinuousTrigger(intervalMs: Long) extends Trigger { Triggers.validate(intervalMs) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/api/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 1a2fbdc1fd116..5bdaebe3b073a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -43,9 +43,12 @@ import org.apache.spark.sql.internal.{InvokeInlineUserDefinedFunction, UserDefin * * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird * - * @tparam IN The input type for the aggregation. - * @tparam BUF The type of the intermediate value of the reduction. - * @tparam OUT The type of the final output result. + * @tparam IN + * The input type for the aggregation. + * @tparam BUF + * The type of the intermediate value of the reduction. + * @tparam OUT + * The type of the final output result. * @since 1.6.0 */ @SerialVersionUID(2093413866369130093L) @@ -58,7 +61,7 @@ abstract class Aggregator[-IN, BUF, OUT] extends Serializable with UserDefinedFu def zero: BUF /** - * Combine two values to produce a new value. For performance, the function may modify `b` and + * Combine two values to produce a new value. For performance, the function may modify `b` and * return it instead of constructing new object for b. * @since 1.6.0 */ @@ -93,8 +96,6 @@ abstract class Aggregator[-IN, BUF, OUT] extends Serializable with UserDefinedFu * @since 1.6.0 */ def toColumn: TypedColumn[IN, OUT] = { - new TypedColumn[IN, OUT]( - InvokeInlineUserDefinedFunction(this, Nil), - outputEncoder) + new TypedColumn[IN, OUT](InvokeInlineUserDefinedFunction(this, Nil), outputEncoder) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/api/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index c9e0e366a7447..6a22cbfaf351e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -52,8 +52,8 @@ sealed abstract class UserDefinedFunction extends UserDefinedFunctionLike { def nullable: Boolean /** - * Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the same - * input. 
+ * Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the + * same input. * * @since 2.3.0 */ @@ -98,7 +98,8 @@ private[spark] case class SparkUserDefinedFunction( outputEncoder: Option[Encoder[_]] = None, givenName: Option[String] = None, nullable: Boolean = true, - deterministic: Boolean = true) extends UserDefinedFunction { + deterministic: Boolean = true) + extends UserDefinedFunction { override def withName(name: String): SparkUserDefinedFunction = { copy(givenName = Option(name)) @@ -169,7 +170,8 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT]( inputEncoder: Encoder[IN], givenName: Option[String] = None, nullable: Boolean = true, - deterministic: Boolean = true) extends UserDefinedFunction { + deterministic: Boolean = true) + extends UserDefinedFunction { override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] = { copy(givenName = Option(name)) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/api/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 9c4499ee243f5..dbe2da8f97341 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -32,9 +32,10 @@ import org.apache.spark.sql.Column * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) * }}} * - * @note When ordering is not defined, an unbounded window frame (rowFrame, unboundedPreceding, - * unboundedFollowing) is used by default. When ordering is defined, a growing window frame - * (rangeFrame, unboundedPreceding, currentRow) is used by default. + * @note + * When ordering is not defined, an unbounded window frame (rowFrame, unboundedPreceding, + * unboundedFollowing) is used by default. When ordering is defined, a growing window frame + * (rangeFrame, unboundedPreceding, currentRow) is used by default. * * @since 1.4.0 */ @@ -47,7 +48,7 @@ object Window { */ @scala.annotation.varargs def partitionBy(colName: String, colNames: String*): WindowSpec = { - spec.partitionBy(colName, colNames : _*) + spec.partitionBy(colName, colNames: _*) } /** @@ -56,7 +57,7 @@ object Window { */ @scala.annotation.varargs def partitionBy(cols: Column*): WindowSpec = { - spec.partitionBy(cols : _*) + spec.partitionBy(cols: _*) } /** @@ -65,7 +66,7 @@ object Window { */ @scala.annotation.varargs def orderBy(colName: String, colNames: String*): WindowSpec = { - spec.orderBy(colName, colNames : _*) + spec.orderBy(colName, colNames: _*) } /** @@ -74,12 +75,12 @@ object Window { */ @scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { - spec.orderBy(cols : _*) + spec.orderBy(cols: _*) } /** - * Value representing the first row in the partition, equivalent to "UNBOUNDED PRECEDING" in SQL. - * This can be used to specify the frame boundaries: + * Value representing the first row in the partition, equivalent to "UNBOUNDED PRECEDING" in + * SQL. This can be used to specify the frame boundaries: * * {{{ * Window.rowsBetween(Window.unboundedPreceding, Window.currentRow) @@ -113,22 +114,22 @@ object Window { def currentRow: Long = 0 /** - * Creates a [[WindowSpec]] with the frame boundaries defined, - * from `start` (inclusive) to `end` (inclusive). + * Creates a [[WindowSpec]] with the frame boundaries defined, from `start` (inclusive) to `end` + * (inclusive). * * Both `start` and `end` are positions relative to the current row. 
For example, "0" means * "current row", while "-1" means the row before the current row, and "5" means the fifth row * after the current row. * - * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, - * and `Window.currentRow` to specify special boundary values, rather than using integral - * values directly. + * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, and + * `Window.currentRow` to specify special boundary values, rather than using integral values + * directly. * - * A row based boundary is based on the position of the row within the partition. - * An offset indicates the number of rows above or below the current row, the frame for the - * current row starts or ends. For instance, given a row based sliding frame with a lower bound - * offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from - * index 4 to index 7. + * A row based boundary is based on the position of the row within the partition. An offset + * indicates the number of rows above or below the current row, the frame for the current row + * starts or ends. For instance, given a row based sliding frame with a lower bound offset of -1 + * and a upper bound offset of +2. The frame for row with index 5 would range from index 4 to + * index 7. * * {{{ * import org.apache.spark.sql.expressions.Window @@ -150,10 +151,12 @@ object Window { * +---+--------+---+ * }}} * - * @param start boundary start, inclusive. The frame is unbounded if this is - * the minimum long value (`Window.unboundedPreceding`). - * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * @param start + * boundary start, inclusive. The frame is unbounded if this is the minimum long value + * (`Window.unboundedPreceding`). + * @param end + * boundary end, inclusive. The frame is unbounded if this is the maximum long value + * (`Window.unboundedFollowing`). * @since 2.1.0 */ // Note: when updating the doc for this method, also update WindowSpec.rowsBetween. @@ -162,25 +165,24 @@ object Window { } /** - * Creates a [[WindowSpec]] with the frame boundaries defined, - * from `start` (inclusive) to `end` (inclusive). + * Creates a [[WindowSpec]] with the frame boundaries defined, from `start` (inclusive) to `end` + * (inclusive). * * Both `start` and `end` are relative to the current row. For example, "0" means "current row", - * while "-1" means one off before the current row, and "5" means the five off after the - * current row. + * while "-1" means one off before the current row, and "5" means the five off after the current + * row. * - * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, - * and `Window.currentRow` to specify special boundary values, rather than using long values + * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, and + * `Window.currentRow` to specify special boundary values, rather than using long values * directly. * - * A range-based boundary is based on the actual value of the ORDER BY - * expression(s). An offset is used to alter the value of the ORDER BY expression, - * for instance if the current ORDER BY expression has a value of 10 and the lower bound offset - * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a - * number of constraints on the ORDER BY expressions: there can be only one expression and this - * expression must have a numerical data type. 
An exception can be made when the offset is - * unbounded, because no value modification is needed, in this case multiple and non-numeric - * ORDER BY expression are allowed. + * A range-based boundary is based on the actual value of the ORDER BY expression(s). An offset + * is used to alter the value of the ORDER BY expression, for instance if the current ORDER BY + * expression has a value of 10 and the lower bound offset is -3, the resulting lower bound for + * the current row will be 10 - 3 = 7. This however puts a number of constraints on the ORDER BY + * expressions: there can be only one expression and this expression must have a numerical data + * type. An exception can be made when the offset is unbounded, because no value modification is + * needed, in this case multiple and non-numeric ORDER BY expression are allowed. * * {{{ * import org.apache.spark.sql.expressions.Window @@ -202,10 +204,12 @@ object Window { * +---+--------+---+ * }}} * - * @param start boundary start, inclusive. The frame is unbounded if this is - * the minimum long value (`Window.unboundedPreceding`). - * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * @param start + * boundary start, inclusive. The frame is unbounded if this is the minimum long value + * (`Window.unboundedPreceding`). + * @param end + * boundary end, inclusive. The frame is unbounded if this is the maximum long value + * (`Window.unboundedFollowing`). * @since 2.1.0 */ // Note: when updating the doc for this method, also update WindowSpec.rangeBetween. @@ -234,4 +238,4 @@ object Window { * @since 1.4.0 */ @Stable -class Window private() // So we can see Window in JavaDoc. +class Window private () // So we can see Window in JavaDoc. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/api/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index a888563d66a71..9abdee9c79ebc 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.internal.{ColumnNode, SortOrder, Window => EvalWindo * @since 1.4.0 */ @Stable -class WindowSpec private[sql]( +class WindowSpec private[sql] ( partitionSpec: Seq[ColumnNode], orderSpec: Seq[SortOrder], frame: Option[WindowFrame]) { @@ -78,15 +78,15 @@ class WindowSpec private[sql]( * "current row", while "-1" means the row before the current row, and "5" means the fifth row * after the current row. * - * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, - * and `Window.currentRow` to specify special boundary values, rather than using integral - * values directly. + * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, and + * `Window.currentRow` to specify special boundary values, rather than using integral values + * directly. * - * A row based boundary is based on the position of the row within the partition. - * An offset indicates the number of rows above or below the current row, the frame for the - * current row starts or ends. For instance, given a row based sliding frame with a lower bound - * offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from - * index 4 to index 7. + * A row based boundary is based on the position of the row within the partition. 
An offset + * indicates the number of rows above or below the current row, the frame for the current row + * starts or ends. For instance, given a row based sliding frame with a lower bound offset of -1 + * and a upper bound offset of +2. The frame for row with index 5 would range from index 4 to + * index 7. * * {{{ * import org.apache.spark.sql.expressions.Window @@ -108,10 +108,12 @@ class WindowSpec private[sql]( * +---+--------+---+ * }}} * - * @param start boundary start, inclusive. The frame is unbounded if this is - * the minimum long value (`Window.unboundedPreceding`). - * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * @param start + * boundary start, inclusive. The frame is unbounded if this is the minimum long value + * (`Window.unboundedPreceding`). + * @param end + * boundary end, inclusive. The frame is unbounded if this is the maximum long value + * (`Window.unboundedFollowing`). * @since 1.4.0 */ // Note: when updating the doc for this method, also update Window.rowsBetween. @@ -136,22 +138,21 @@ class WindowSpec private[sql]( /** * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). * - * Both `start` and `end` are relative from the current row. For example, "0" means - * "current row", while "-1" means one off before the current row, and "5" means the five off - * after the current row. + * Both `start` and `end` are relative from the current row. For example, "0" means "current + * row", while "-1" means one off before the current row, and "5" means the five off after the + * current row. * - * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, - * and `Window.currentRow` to specify special boundary values, rather than using long values + * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, and + * `Window.currentRow` to specify special boundary values, rather than using long values * directly. * - * A range-based boundary is based on the actual value of the ORDER BY - * expression(s). An offset is used to alter the value of the ORDER BY expression, for - * instance if the current order by expression has a value of 10 and the lower bound offset - * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a - * number of constraints on the ORDER BY expressions: there can be only one expression and this - * expression must have a numerical data type. An exception can be made when the offset is - * unbounded, because no value modification is needed, in this case multiple and non-numeric - * ORDER BY expression are allowed. + * A range-based boundary is based on the actual value of the ORDER BY expression(s). An offset + * is used to alter the value of the ORDER BY expression, for instance if the current order by + * expression has a value of 10 and the lower bound offset is -3, the resulting lower bound for + * the current row will be 10 - 3 = 7. This however puts a number of constraints on the ORDER BY + * expressions: there can be only one expression and this expression must have a numerical data + * type. An exception can be made when the offset is unbounded, because no value modification is + * needed, in this case multiple and non-numeric ORDER BY expression are allowed. * * {{{ * import org.apache.spark.sql.expressions.Window @@ -173,10 +174,12 @@ class WindowSpec private[sql]( * +---+--------+---+ * }}} * - * @param start boundary start, inclusive. 
The frame is unbounded if this is - * the minimum long value (`Window.unboundedPreceding`). - * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * @param start + * boundary start, inclusive. The frame is unbounded if this is the minimum long value + * (`Window.unboundedPreceding`). + * @param end + * boundary end, inclusive. The frame is unbounded if this is the maximum long value + * (`Window.unboundedFollowing`). * @since 1.4.0 */ // Note: when updating the doc for this method, also update Window.rangeBetween. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/api/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index a4aa9c312aff2..5e7c993fae414 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -26,18 +26,21 @@ import org.apache.spark.sql.types._ * The base class for implementing user-defined aggregate functions (UDAF). * * @since 1.5.0 - * @deprecated UserDefinedAggregateFunction is deprecated. - * Aggregator[IN, BUF, OUT] should now be registered as a UDF via the functions.udaf(agg) method. + * @deprecated + * UserDefinedAggregateFunction is deprecated. Aggregator[IN, BUF, OUT] should now be registered + * as a UDF via the functions.udaf(agg) method. */ @Stable -@deprecated("Aggregator[IN, BUF, OUT] should now be registered as a UDF" + - " via the functions.udaf(agg) method.", "3.0.0") +@deprecated( + "Aggregator[IN, BUF, OUT] should now be registered as a UDF" + + " via the functions.udaf(agg) method.", + "3.0.0") abstract class UserDefinedAggregateFunction extends Serializable with UserDefinedFunctionLike { /** - * A `StructType` represents data types of input arguments of this aggregate function. - * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments - * with type of `DoubleType` and `LongType`, the returned `StructType` will look like + * A `StructType` represents data types of input arguments of this aggregate function. For + * example, if a [[UserDefinedAggregateFunction]] expects two input arguments with type of + * `DoubleType` and `LongType`, the returned `StructType` will look like * * ``` * new StructType() @@ -45,18 +48,17 @@ abstract class UserDefinedAggregateFunction extends Serializable with UserDefine * .add("longInput", LongType) * ``` * - * The name of a field of this `StructType` is only used to identify the corresponding - * input argument. Users can choose names to identify the input arguments. + * The name of a field of this `StructType` is only used to identify the corresponding input + * argument. Users can choose names to identify the input arguments. * * @since 1.5.0 */ def inputSchema: StructType /** - * A `StructType` represents data types of values in the aggregation buffer. - * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values - * (i.e. two intermediate values) with type of `DoubleType` and `LongType`, - * the returned `StructType` will look like + * A `StructType` represents data types of values in the aggregation buffer. For example, if a + * [[UserDefinedAggregateFunction]]'s buffer has two values (i.e. 
two intermediate values) with + * type of `DoubleType` and `LongType`, the returned `StructType` will look like * * ``` * new StructType() @@ -64,8 +66,8 @@ abstract class UserDefinedAggregateFunction extends Serializable with UserDefine * .add("longInput", LongType) * ``` * - * The name of a field of this `StructType` is only used to identify the corresponding - * buffer value. Users can choose names to identify the input arguments. + * The name of a field of this `StructType` is only used to identify the corresponding buffer + * value. Users can choose names to identify the input arguments. * * @since 1.5.0 */ @@ -79,8 +81,8 @@ abstract class UserDefinedAggregateFunction extends Serializable with UserDefine def dataType: DataType /** - * Returns true iff this function is deterministic, i.e. given the same input, - * always return the same output. + * Returns true iff this function is deterministic, i.e. given the same input, always return the + * same output. * * @since 1.5.0 */ @@ -90,8 +92,8 @@ abstract class UserDefinedAggregateFunction extends Serializable with UserDefine * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer. * * The contract should be that applying the merge function on two initial buffers should just - * return the initial buffer itself, i.e. - * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`. + * return the initial buffer itself, i.e. `merge(initialBuffer, initialBuffer)` should equal + * `initialBuffer`. * * @since 1.5.0 */ @@ -134,8 +136,8 @@ abstract class UserDefinedAggregateFunction extends Serializable with UserDefine } /** - * Creates a `Column` for this UDAF using the distinct values of the given - * `Column`s as input arguments. + * Creates a `Column` for this UDAF using the distinct values of the given `Column`s as input + * arguments. * * @since 1.5.0 */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index e46d6c95b31ae..0662b8f2b271f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -104,9 +104,9 @@ object functions { /** * Creates a [[Column]] of literal value. * - * The passed in object is returned directly if it is already a [[Column]]. - * If the object is a Scala Symbol, it is converted into a [[Column]] also. - * Otherwise, a new [[Column]] is created to represent the literal value. + * The passed in object is returned directly if it is already a [[Column]]. If the object is a + * Scala Symbol, it is converted into a [[Column]] also. Otherwise, a new [[Column]] is created + * to represent the literal value. * * @group normal_funcs * @since 1.3.0 @@ -121,7 +121,7 @@ object functions { // method, `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. Hence, // we can just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. // This is significantly better when there are many threads calling `lit` concurrently. - Column(internal.Literal(literal)) + Column(internal.Literal(literal)) } } @@ -133,26 +133,26 @@ object functions { * @group normal_funcs * @since 2.2.0 */ - def typedLit[T : TypeTag](literal: T): Column = { + def typedLit[T: TypeTag](literal: T): Column = { typedlit(literal) } /** * Creates a [[Column]] of literal value. * - * The passed in object is returned directly if it is already a [[Column]]. 
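// Illustrative sketch (not part of the diff): the replacement pattern that the deprecation note
// above points to -- implementing an Aggregator and registering it through functions.udaf
// instead of extending UserDefinedAggregateFunction. Assumes a SparkSession `spark`; the
// "sales"/"amount" names are hypothetical.
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, udaf}

object MyAverage extends Aggregator[Double, (Double, Long), Double] {
  def zero: (Double, Long) = (0.0, 0L)
  def reduce(b: (Double, Long), a: Double): (Double, Long) = (b._1 + a, b._2 + 1)
  def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) =
    (b1._1 + b2._1, b1._2 + b2._2)
  def finish(r: (Double, Long)): Double = if (r._2 == 0L) Double.NaN else r._1 / r._2
  def bufferEncoder: Encoder[(Double, Long)] =
    Encoders.tuple(Encoders.scalaDouble, Encoders.scalaLong)
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

val myAverage = udaf(MyAverage)              // Column-based use: sales.agg(myAverage(col("amount")))
spark.udf.register("my_average", myAverage)  // or SQL: SELECT my_average(amount) FROM sales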
- * If the object is a Scala Symbol, it is converted into a [[Column]] also. - * Otherwise, a new [[Column]] is created to represent the literal value. - * The difference between this function and [[lit]] is that this function - * can handle parameterized scala types e.g.: List, Seq and Map. + * The passed in object is returned directly if it is already a [[Column]]. If the object is a + * Scala Symbol, it is converted into a [[Column]] also. Otherwise, a new [[Column]] is created + * to represent the literal value. The difference between this function and [[lit]] is that this + * function can handle parameterized scala types e.g.: List, Seq and Map. * - * @note `typedlit` will call expensive Scala reflection APIs. `lit` is preferred if parameterized - * Scala types are not used. + * @note + * `typedlit` will call expensive Scala reflection APIs. `lit` is preferred if parameterized + * Scala types are not used. * * @group normal_funcs * @since 3.2.0 */ - def typedlit[T : TypeTag](literal: T): Column = { + def typedlit[T: TypeTag](literal: T): Column = { literal match { case c: Column => c case s: Symbol => new ColumnName(s.name) @@ -178,8 +178,8 @@ object functions { def asc(columnName: String): Column = Column(columnName).asc /** - * Returns a sort expression based on ascending order of the column, - * and null values return before non-null values. + * Returns a sort expression based on ascending order of the column, and null values return + * before non-null values. * {{{ * df.sort(asc_nulls_first("dept"), desc("age")) * }}} @@ -190,8 +190,8 @@ object functions { def asc_nulls_first(columnName: String): Column = Column(columnName).asc_nulls_first /** - * Returns a sort expression based on ascending order of the column, - * and null values appear after non-null values. + * Returns a sort expression based on ascending order of the column, and null values appear + * after non-null values. * {{{ * df.sort(asc_nulls_last("dept"), desc("age")) * }}} @@ -213,8 +213,8 @@ object functions { def desc(columnName: String): Column = Column(columnName).desc /** - * Returns a sort expression based on the descending order of the column, - * and null values appear before non-null values. + * Returns a sort expression based on the descending order of the column, and null values appear + * before non-null values. * {{{ * df.sort(asc("dept"), desc_nulls_first("age")) * }}} @@ -225,8 +225,8 @@ object functions { def desc_nulls_first(columnName: String): Column = Column(columnName).desc_nulls_first /** - * Returns a sort expression based on the descending order of the column, - * and null values appear after non-null values. + * Returns a sort expression based on the descending order of the column, and null values appear + * after non-null values. * {{{ * df.sort(asc("dept"), desc_nulls_last("age")) * }}} @@ -236,7 +236,6 @@ object functions { */ def desc_nulls_last(columnName: String): Column = Column(columnName).desc_nulls_last - ////////////////////////////////////////////////////////////////////////////////////////////// // Aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -285,12 +284,14 @@ object functions { * @group agg_funcs * @since 2.1.0 */ - def approx_count_distinct(columnName: String): Column = approx_count_distinct(column(columnName)) + def approx_count_distinct(columnName: String): Column = approx_count_distinct( + column(columnName)) /** * Aggregate function: returns the approximate number of distinct items in a group. 
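// Illustrative sketch (not part of the diff): lit vs. typedLit as described above. lit covers
// simple literals cheaply; typedLit additionally handles parameterized Scala types such as Seq
// and Map at the cost of Scala reflection. Assumes a hypothetical DataFrame `df`.
import org.apache.spark.sql.functions.{lit, typedLit}

df.select(
  lit(42).as("answer"),                                     // plain literal
  typedLit(Seq(1, 2, 3)).as("int_array"),                   // array<int> literal
  typedLit(Map("a" -> 1, "b" -> 2)).as("str_to_int_map"))   // map<string,int> literal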
* - * @param rsd maximum relative standard deviation allowed (default = 0.05) + * @param rsd + * maximum relative standard deviation allowed (default = 0.05) * * @group agg_funcs * @since 2.1.0 @@ -302,7 +303,8 @@ object functions { /** * Aggregate function: returns the approximate number of distinct items in a group. * - * @param rsd maximum relative standard deviation allowed (default = 0.05) + * @param rsd + * maximum relative standard deviation allowed (default = 0.05) * * @group agg_funcs * @since 2.1.0 @@ -330,8 +332,9 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * - * @note The function is non-deterministic because the order of collected results depends - * on the order of the rows which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because the order of collected results depends on the + * order of the rows which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 1.6.0 @@ -341,8 +344,9 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * - * @note The function is non-deterministic because the order of collected results depends - * on the order of the rows which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because the order of collected results depends on the + * order of the rows which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 1.6.0 @@ -352,8 +356,9 @@ object functions { /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * - * @note The function is non-deterministic because the order of collected results depends - * on the order of the rows which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because the order of collected results depends on the + * order of the rows which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 1.6.0 @@ -363,8 +368,9 @@ object functions { /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * - * @note The function is non-deterministic because the order of collected results depends - * on the order of the rows which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because the order of collected results depends on the + * order of the rows which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 1.6.0 @@ -372,10 +378,10 @@ object functions { def collect_set(columnName: String): Column = collect_set(Column(columnName)) /** - * Returns a count-min sketch of a column with the given esp, confidence and seed. The result - * is an array of bytes, which can be deserialized to a `CountMinSketch` before usage. - * Count-min sketch is a probabilistic data structure used for cardinality estimation using - * sub-linear space. + * Returns a count-min sketch of a column with the given esp, confidence and seed. The result is + * an array of bytes, which can be deserialized to a `CountMinSketch` before usage. Count-min + * sketch is a probabilistic data structure used for cardinality estimation using sub-linear + * space. * * @group agg_funcs * @since 3.5.0 @@ -383,6 +389,18 @@ object functions { def count_min_sketch(e: Column, eps: Column, confidence: Column, seed: Column): Column = Column.fn("count_min_sketch", e, eps, confidence, seed) + /** + * Returns a count-min sketch of a column with the given esp, confidence and seed. 
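// Illustrative sketch (not part of the diff): the three-argument count_min_sketch overload added
// above (the seed is chosen randomly). The resulting binary column can be deserialized with
// org.apache.spark.util.sketch.CountMinSketch.readFrom. The DataFrame and column names are
// hypothetical.
import org.apache.spark.sql.functions.{col, count_min_sketch, lit}
import org.apache.spark.util.sketch.CountMinSketch

val sketchBytes = events
  .agg(count_min_sketch(col("item"), lit(0.01), lit(0.95)).as("cms"))  // eps, confidence
  .head().getAs[Array[Byte]]("cms")
val cms = CountMinSketch.readFrom(sketchBytes)
cms.estimateCount("some_item")   // approximate frequency of one item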
The result is + * an array of bytes, which can be deserialized to a `CountMinSketch` before usage. Count-min + * sketch is a probabilistic data structure used for cardinality estimation using sub-linear + * space. + * + * @group agg_funcs + * @since 4.0.0 + */ + def count_min_sketch(e: Column, eps: Column, confidence: Column): Column = + count_min_sketch(e, eps, confidence, lit(SparkClassUtils.random.nextInt)) + private[spark] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column = Column.internalFn("collect_top_k", e, lit(num), lit(reverse)) @@ -443,7 +461,7 @@ object functions { */ @scala.annotation.varargs def countDistinct(columnName: String, columnNames: String*): Column = - count_distinct(Column(columnName), columnNames.map(Column.apply) : _*) + count_distinct(Column(columnName), columnNames.map(Column.apply): _*) /** * Aggregate function: returns the number of distinct items in a group. @@ -499,8 +517,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * - * @note The function is non-deterministic because its results depends on the order of the rows - * which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because its results depends on the order of the rows + * which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 2.0.0 @@ -514,8 +533,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * - * @note The function is non-deterministic because its results depends on the order of the rows - * which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because its results depends on the order of the rows + * which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 2.0.0 @@ -530,8 +550,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * - * @note The function is non-deterministic because its results depends on the order of the rows - * which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because its results depends on the order of the rows + * which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 1.3.0 @@ -544,8 +565,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * - * @note The function is non-deterministic because its results depends on the order of the rows - * which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because its results depends on the order of the rows + * which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 1.3.0 @@ -555,8 +577,9 @@ object functions { /** * Aggregate function: returns the first value in a group. * - * @note The function is non-deterministic because its results depends on the order of the rows - * which may be non-deterministic after a shuffle. 
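// Illustrative sketch (not part of the diff): first/first_value with ignoreNulls, per the notes
// above. As those notes say, the result depends on the row order after a shuffle, so the picked
// value is not guaranteed to be stable. The DataFrame and column names are hypothetical.
import org.apache.spark.sql.functions.{col, first}

events
  .groupBy(col("user"))
  .agg(first(col("page"), ignoreNulls = true).as("some_non_null_page"))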
+ * @note + * The function is non-deterministic because its results depends on the order of the rows + * which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 3.5.0 @@ -569,8 +592,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * - * @note The function is non-deterministic because its results depends on the order of the rows - * which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because its results depends on the order of the rows + * which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 3.5.0 @@ -579,8 +603,8 @@ object functions { Column.fn("first_value", e, ignoreNulls) /** - * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated - * or not, returns 1 for aggregated or 0 for not aggregated in the result set. + * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated or + * not, returns 1 for aggregated or 0 for not aggregated in the result set. * * @group agg_funcs * @since 2.0.0 @@ -588,8 +612,8 @@ object functions { def grouping(e: Column): Column = Column.fn("grouping", e) /** - * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated - * or not, returns 1 for aggregated or 0 for not aggregated in the result set. + * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated or + * not, returns 1 for aggregated or 0 for not aggregated in the result set. * * @group agg_funcs * @since 2.0.0 @@ -603,8 +627,9 @@ object functions { * (grouping(c1) <<; (n-1)) + (grouping(c2) <<; (n-2)) + ... + grouping(cn) * }}} * - * @note The list of columns should match with grouping columns exactly, or empty (means all the - * grouping columns). + * @note + * The list of columns should match with grouping columns exactly, or empty (means all the + * grouping columns). * * @group agg_funcs * @since 2.0.0 @@ -618,18 +643,19 @@ object functions { * (grouping(c1) <<; (n-1)) + (grouping(c2) <<; (n-2)) + ... + grouping(cn) * }}} * - * @note The list of columns should match with grouping columns exactly. + * @note + * The list of columns should match with grouping columns exactly. * * @group agg_funcs * @since 2.0.0 */ def grouping_id(colName: String, colNames: String*): Column = { - grouping_id((Seq(colName) ++ colNames).map(n => Column(n)) : _*) + grouping_id((Seq(colName) ++ colNames).map(n => Column(n)): _*) } /** - * Aggregate function: returns the updatable binary representation of the Datasketches - * HllSketch configured with lgConfigK arg. + * Aggregate function: returns the updatable binary representation of the Datasketches HllSketch + * configured with lgConfigK arg. * * @group agg_funcs * @since 3.5.0 @@ -638,8 +664,8 @@ object functions { Column.fn("hll_sketch_agg", e, lgConfigK) /** - * Aggregate function: returns the updatable binary representation of the Datasketches - * HllSketch configured with lgConfigK arg. + * Aggregate function: returns the updatable binary representation of the Datasketches HllSketch + * configured with lgConfigK arg. * * @group agg_funcs * @since 3.5.0 @@ -648,8 +674,8 @@ object functions { Column.fn("hll_sketch_agg", e, lit(lgConfigK)) /** - * Aggregate function: returns the updatable binary representation of the Datasketches - * HllSketch configured with lgConfigK arg. 
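// Illustrative sketch (not part of the diff): grouping and grouping_id over a CUBE, following the
// formula above ((grouping(c1) << (n-1)) + ... + grouping(cn)). The DataFrame and column names
// are hypothetical.
import org.apache.spark.sql.functions.{col, grouping, grouping_id, sum}

sales
  .cube("dept", "region")
  .agg(
    sum(col("amount")).as("total"),
    grouping(col("dept")).as("dept_aggregated"),  // 1 when "dept" is rolled up
    grouping_id().as("gid"))                      // 0..3 for the four grouping sets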
+ * Aggregate function: returns the updatable binary representation of the Datasketches HllSketch + * configured with lgConfigK arg. * * @group agg_funcs * @since 3.5.0 @@ -659,8 +685,8 @@ object functions { } /** - * Aggregate function: returns the updatable binary representation of the Datasketches - * HllSketch configured with default lgConfigK value. + * Aggregate function: returns the updatable binary representation of the Datasketches HllSketch + * configured with default lgConfigK value. * * @group agg_funcs * @since 3.5.0 @@ -669,8 +695,8 @@ object functions { Column.fn("hll_sketch_agg", e) /** - * Aggregate function: returns the updatable binary representation of the Datasketches - * HllSketch configured with default lgConfigK value. + * Aggregate function: returns the updatable binary representation of the Datasketches HllSketch + * configured with default lgConfigK value. * * @group agg_funcs * @since 3.5.0 @@ -681,9 +707,9 @@ object functions { /** * Aggregate function: returns the updatable binary representation of the Datasketches - * HllSketch, generated by merging previously created Datasketches HllSketch instances - * via a Datasketches Union instance. Throws an exception if sketches have different - * lgConfigK values and allowDifferentLgConfigK is set to false. + * HllSketch, generated by merging previously created Datasketches HllSketch instances via a + * Datasketches Union instance. Throws an exception if sketches have different lgConfigK values + * and allowDifferentLgConfigK is set to false. * * @group agg_funcs * @since 3.5.0 @@ -693,9 +719,9 @@ object functions { /** * Aggregate function: returns the updatable binary representation of the Datasketches - * HllSketch, generated by merging previously created Datasketches HllSketch instances - * via a Datasketches Union instance. Throws an exception if sketches have different - * lgConfigK values and allowDifferentLgConfigK is set to false. + * HllSketch, generated by merging previously created Datasketches HllSketch instances via a + * Datasketches Union instance. Throws an exception if sketches have different lgConfigK values + * and allowDifferentLgConfigK is set to false. * * @group agg_funcs * @since 3.5.0 @@ -705,9 +731,9 @@ object functions { /** * Aggregate function: returns the updatable binary representation of the Datasketches - * HllSketch, generated by merging previously created Datasketches HllSketch instances - * via a Datasketches Union instance. Throws an exception if sketches have different - * lgConfigK values and allowDifferentLgConfigK is set to false. + * HllSketch, generated by merging previously created Datasketches HllSketch instances via a + * Datasketches Union instance. Throws an exception if sketches have different lgConfigK values + * and allowDifferentLgConfigK is set to false. * * @group agg_funcs * @since 3.5.0 @@ -718,9 +744,8 @@ object functions { /** * Aggregate function: returns the updatable binary representation of the Datasketches - * HllSketch, generated by merging previously created Datasketches HllSketch instances - * via a Datasketches Union instance. Throws an exception if sketches have different - * lgConfigK values. + * HllSketch, generated by merging previously created Datasketches HllSketch instances via a + * Datasketches Union instance. Throws an exception if sketches have different lgConfigK values. 
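// Illustrative sketch (not part of the diff): building per-group HllSketch binaries and merging
// them, per the Datasketches aggregates documented above; hll_sketch_estimate turns a sketch back
// into an approximate distinct count. The DataFrame and column names are hypothetical.
import org.apache.spark.sql.functions.{col, hll_sketch_agg, hll_sketch_estimate, hll_union_agg}

val daily = visits
  .groupBy("day")
  .agg(hll_sketch_agg(col("user_id"), 12).as("sketch"))   // lgConfigK = 12

// Merge the daily sketches and estimate the overall number of distinct users.
daily.agg(hll_sketch_estimate(hll_union_agg(col("sketch"))).as("approx_distinct_users"))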
* * @group agg_funcs * @since 3.5.0 @@ -730,9 +755,8 @@ object functions { /** * Aggregate function: returns the updatable binary representation of the Datasketches - * HllSketch, generated by merging previously created Datasketches HllSketch instances - * via a Datasketches Union instance. Throws an exception if sketches have different - * lgConfigK values. + * HllSketch, generated by merging previously created Datasketches HllSketch instances via a + * Datasketches Union instance. Throws an exception if sketches have different lgConfigK values. * * @group agg_funcs * @since 3.5.0 @@ -763,8 +787,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * - * @note The function is non-deterministic because its results depends on the order of the rows - * which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because its results depends on the order of the rows + * which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 2.0.0 @@ -778,8 +803,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * - * @note The function is non-deterministic because its results depends on the order of the rows - * which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because its results depends on the order of the rows + * which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 2.0.0 @@ -794,8 +820,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * - * @note The function is non-deterministic because its results depends on the order of the rows - * which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because its results depends on the order of the rows + * which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 1.3.0 @@ -808,8 +835,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * - * @note The function is non-deterministic because its results depends on the order of the rows - * which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because its results depends on the order of the rows + * which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 1.3.0 @@ -819,8 +847,9 @@ object functions { /** * Aggregate function: returns the last value in a group. * - * @note The function is non-deterministic because its results depends on the order of the rows - * which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because its results depends on the order of the rows + * which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 3.5.0 @@ -833,8 +862,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. 
* - * @note The function is non-deterministic because its results depends on the order of the rows - * which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because its results depends on the order of the rows + * which may be non-deterministic after a shuffle. * * @group agg_funcs * @since 3.5.0 @@ -881,8 +911,9 @@ object functions { /** * Aggregate function: returns the value associated with the maximum value of ord. * - * @note The function is non-deterministic so the output order can be different for - * those associated the same values of `e`. + * @note + * The function is non-deterministic so the output order can be different for those associated + * the same values of `e`. * * @group agg_funcs * @since 3.3.0 @@ -890,8 +921,7 @@ object functions { def max_by(e: Column, ord: Column): Column = Column.fn("max_by", e, ord) /** - * Aggregate function: returns the average of the values in a group. - * Alias for avg. + * Aggregate function: returns the average of the values in a group. Alias for avg. * * @group agg_funcs * @since 1.4.0 @@ -899,8 +929,7 @@ object functions { def mean(e: Column): Column = avg(e) /** - * Aggregate function: returns the average of the values in a group. - * Alias for avg. + * Aggregate function: returns the average of the values in a group. Alias for avg. * * @group agg_funcs * @since 1.4.0 @@ -934,8 +963,9 @@ object functions { /** * Aggregate function: returns the value associated with the minimum value of ord. * - * @note The function is non-deterministic so the output order can be different for - * those associated the same values of `e`. + * @note + * The function is non-deterministic so the output order can be different for those associated + * the same values of `e`. * * @group agg_funcs * @since 3.3.0 @@ -943,8 +973,8 @@ object functions { def min_by(e: Column, ord: Column): Column = Column.fn("min_by", e, ord) /** - * Aggregate function: returns the exact percentile(s) of numeric column `expr` at the - * given percentage(s) with value range in [0.0, 1.0]. + * Aggregate function: returns the exact percentile(s) of numeric column `expr` at the given + * percentage(s) with value range in [0.0, 1.0]. * * @group agg_funcs * @since 3.5.0 @@ -952,8 +982,8 @@ object functions { def percentile(e: Column, percentage: Column): Column = Column.fn("percentile", e, percentage) /** - * Aggregate function: returns the exact percentile(s) of numeric column `expr` at the - * given percentage(s) with value range in [0.0, 1.0]. + * Aggregate function: returns the exact percentile(s) of numeric column `expr` at the given + * percentage(s) with value range in [0.0, 1.0]. * * @group agg_funcs * @since 3.5.0 @@ -962,17 +992,16 @@ object functions { Column.fn("percentile", e, percentage, frequency) /** - * Aggregate function: returns the approximate `percentile` of the numeric column `col` which - * is the smallest value in the ordered `col` values (sorted from least to greatest) such that - * no more than `percentage` of `col` values is less than the value or equal to that value. + * Aggregate function: returns the approximate `percentile` of the numeric column `col` which is + * the smallest value in the ordered `col` values (sorted from least to greatest) such that no + * more than `percentage` of `col` values is less than the value or equal to that value. * - * If percentage is an array, each value must be between 0.0 and 1.0. - * If it is a single floating point value, it must be between 0.0 and 1.0. 
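// Illustrative sketch (not part of the diff): max_by/min_by, which return the value of one column
// at the row where another column is maximal/minimal (ties broken non-deterministically, per the
// notes above). The DataFrame and column names are hypothetical.
import org.apache.spark.sql.functions.{col, max_by, min_by}

employees
  .groupBy("dept")
  .agg(
    max_by(col("name"), col("salary")).as("highest_paid"),
    min_by(col("name"), col("salary")).as("lowest_paid"))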
+ * If percentage is an array, each value must be between 0.0 and 1.0. If it is a single floating + * point value, it must be between 0.0 and 1.0. * - * The accuracy parameter is a positive numeric literal - * which controls approximation accuracy at the cost of memory. - * Higher value of accuracy yields better accuracy, 1.0/accuracy - * is the relative error of the approximation. + * The accuracy parameter is a positive numeric literal which controls approximation accuracy at + * the cost of memory. Higher value of accuracy yields better accuracy, 1.0/accuracy is the + * relative error of the approximation. * * @group agg_funcs * @since 3.1.0 @@ -981,17 +1010,16 @@ object functions { Column.fn("percentile_approx", e, percentage, accuracy) /** - * Aggregate function: returns the approximate `percentile` of the numeric column `col` which - * is the smallest value in the ordered `col` values (sorted from least to greatest) such that - * no more than `percentage` of `col` values is less than the value or equal to that value. + * Aggregate function: returns the approximate `percentile` of the numeric column `col` which is + * the smallest value in the ordered `col` values (sorted from least to greatest) such that no + * more than `percentage` of `col` values is less than the value or equal to that value. * - * If percentage is an array, each value must be between 0.0 and 1.0. - * If it is a single floating point value, it must be between 0.0 and 1.0. + * If percentage is an array, each value must be between 0.0 and 1.0. If it is a single floating + * point value, it must be between 0.0 and 1.0. * - * The accuracy parameter is a positive numeric literal - * which controls approximation accuracy at the cost of memory. - * Higher value of accuracy yields better accuracy, 1.0/accuracy - * is the relative error of the approximation. + * The accuracy parameter is a positive numeric literal which controls approximation accuracy at + * the cost of memory. Higher value of accuracy yields better accuracy, 1.0/accuracy is the + * relative error of the approximation. * * @group agg_funcs * @since 3.5.0 @@ -1049,8 +1077,7 @@ object functions { def stddev(columnName: String): Column = stddev(Column(columnName)) /** - * Aggregate function: returns the sample standard deviation of - * the expression in a group. + * Aggregate function: returns the sample standard deviation of the expression in a group. * * @group agg_funcs * @since 1.6.0 @@ -1058,8 +1085,7 @@ object functions { def stddev_samp(e: Column): Column = Column.fn("stddev_samp", e) /** - * Aggregate function: returns the sample standard deviation of - * the expression in a group. + * Aggregate function: returns the sample standard deviation of the expression in a group. * * @group agg_funcs * @since 1.6.0 @@ -1067,8 +1093,7 @@ object functions { def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName)) /** - * Aggregate function: returns the population standard deviation of - * the expression in a group. + * Aggregate function: returns the population standard deviation of the expression in a group. * * @group agg_funcs * @since 1.6.0 @@ -1076,8 +1101,7 @@ object functions { def stddev_pop(e: Column): Column = Column.fn("stddev_pop", e) /** - * Aggregate function: returns the population standard deviation of - * the expression in a group. + * Aggregate function: returns the population standard deviation of the expression in a group. 
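// Illustrative sketch (not part of the diff): percentile_approx with the accuracy parameter
// described above (relative error is roughly 1.0/accuracy), plus a multi-percentile call using
// typedlit for the array argument. The DataFrame and column names are hypothetical.
import org.apache.spark.sql.functions.{col, lit, percentile_approx, typedlit}

requests.agg(
  percentile_approx(col("latency_ms"), lit(0.95), lit(10000)).as("p95"),
  percentile_approx(col("latency_ms"), typedlit(Seq(0.5, 0.9, 0.99)), lit(10000)).as("p50_90_99"))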
* * @group agg_funcs * @since 1.6.0 @@ -1175,8 +1199,8 @@ object functions { def var_pop(columnName: String): Column = var_pop(Column(columnName)) /** - * Aggregate function: returns the average of the independent variable for non-null pairs - * in a group, where `y` is the dependent variable and `x` is the independent variable. + * Aggregate function: returns the average of the independent variable for non-null pairs in a + * group, where `y` is the dependent variable and `x` is the independent variable. * * @group agg_funcs * @since 3.5.0 @@ -1184,8 +1208,8 @@ object functions { def regr_avgx(y: Column, x: Column): Column = Column.fn("regr_avgx", y, x) /** - * Aggregate function: returns the average of the independent variable for non-null pairs - * in a group, where `y` is the dependent variable and `x` is the independent variable. + * Aggregate function: returns the average of the independent variable for non-null pairs in a + * group, where `y` is the dependent variable and `x` is the independent variable. * * @group agg_funcs * @since 3.5.0 @@ -1193,8 +1217,8 @@ object functions { def regr_avgy(y: Column, x: Column): Column = Column.fn("regr_avgy", y, x) /** - * Aggregate function: returns the number of non-null number pairs - * in a group, where `y` is the dependent variable and `x` is the independent variable. + * Aggregate function: returns the number of non-null number pairs in a group, where `y` is the + * dependent variable and `x` is the independent variable. * * @group agg_funcs * @since 3.5.0 @@ -1202,9 +1226,9 @@ object functions { def regr_count(y: Column, x: Column): Column = Column.fn("regr_count", y, x) /** - * Aggregate function: returns the intercept of the univariate linear regression line - * for non-null pairs in a group, where `y` is the dependent variable and - * `x` is the independent variable. + * Aggregate function: returns the intercept of the univariate linear regression line for + * non-null pairs in a group, where `y` is the dependent variable and `x` is the independent + * variable. * * @group agg_funcs * @since 3.5.0 @@ -1212,8 +1236,8 @@ object functions { def regr_intercept(y: Column, x: Column): Column = Column.fn("regr_intercept", y, x) /** - * Aggregate function: returns the coefficient of determination for non-null pairs - * in a group, where `y` is the dependent variable and `x` is the independent variable. + * Aggregate function: returns the coefficient of determination for non-null pairs in a group, + * where `y` is the dependent variable and `x` is the independent variable. * * @group agg_funcs * @since 3.5.0 @@ -1221,8 +1245,8 @@ object functions { def regr_r2(y: Column, x: Column): Column = Column.fn("regr_r2", y, x) /** - * Aggregate function: returns the slope of the linear regression line for non-null pairs - * in a group, where `y` is the dependent variable and `x` is the independent variable. + * Aggregate function: returns the slope of the linear regression line for non-null pairs in a + * group, where `y` is the dependent variable and `x` is the independent variable. * * @group agg_funcs * @since 3.5.0 @@ -1230,8 +1254,8 @@ object functions { def regr_slope(y: Column, x: Column): Column = Column.fn("regr_slope", y, x) /** - * Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(x) for non-null pairs - * in a group, where `y` is the dependent variable and `x` is the independent variable. 
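// Illustrative sketch (not part of the diff): the regr_* aggregates documented above, fitting a
// univariate linear regression y ~ x per group over the non-null pairs. The DataFrame and column
// names are hypothetical.
import org.apache.spark.sql.functions.{col, regr_count, regr_intercept, regr_r2, regr_slope}

points
  .groupBy("group")
  .agg(
    regr_slope(col("y"), col("x")).as("slope"),
    regr_intercept(col("y"), col("x")).as("intercept"),
    regr_r2(col("y"), col("x")).as("r2"),
    regr_count(col("y"), col("x")).as("n"))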
+ * Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(x) for non-null pairs in a group, + * where `y` is the dependent variable and `x` is the independent variable. * * @group agg_funcs * @since 3.5.0 @@ -1239,8 +1263,8 @@ object functions { def regr_sxx(y: Column, x: Column): Column = Column.fn("regr_sxx", y, x) /** - * Aggregate function: returns REGR_COUNT(y, x) * COVAR_POP(y, x) for non-null pairs - * in a group, where `y` is the dependent variable and `x` is the independent variable. + * Aggregate function: returns REGR_COUNT(y, x) * COVAR_POP(y, x) for non-null pairs in a group, + * where `y` is the dependent variable and `x` is the independent variable. * * @group agg_funcs * @since 3.5.0 @@ -1248,8 +1272,8 @@ object functions { def regr_sxy(y: Column, x: Column): Column = Column.fn("regr_sxy", y, x) /** - * Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(y) for non-null pairs - * in a group, where `y` is the dependent variable and `x` is the independent variable. + * Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(y) for non-null pairs in a group, + * where `y` is the dependent variable and `x` is the independent variable. * * @group agg_funcs * @since 3.5.0 @@ -1265,8 +1289,8 @@ object functions { def any_value(e: Column): Column = Column.fn("any_value", e) /** - * Aggregate function: returns some value of `e` for a group of rows. - * If `isIgnoreNull` is true, returns only non-null values. + * Aggregate function: returns some value of `e` for a group of rows. If `isIgnoreNull` is true, + * returns only non-null values. * * @group agg_funcs * @since 3.5.0 @@ -1283,13 +1307,12 @@ object functions { def count_if(e: Column): Column = Column.fn("count_if", e) /** - * Aggregate function: computes a histogram on numeric 'expr' using nb bins. - * The return value is an array of (x,y) pairs representing the centers of the - * histogram's bins. As the value of 'nb' is increased, the histogram approximation - * gets finer-grained, but may yield artifacts around outliers. In practice, 20-40 - * histogram bins appear to work well, with more bins being required for skewed or - * smaller datasets. Note that this function creates a histogram with non-uniform - * bin widths. It offers no guarantees in terms of the mean-squared-error of the + * Aggregate function: computes a histogram on numeric 'expr' using nb bins. The return value is + * an array of (x,y) pairs representing the centers of the histogram's bins. As the value of + * 'nb' is increased, the histogram approximation gets finer-grained, but may yield artifacts + * around outliers. In practice, 20-40 histogram bins appear to work well, with more bins being + * required for skewed or smaller datasets. Note that this function creates a histogram with + * non-uniform bin widths. It offers no guarantees in terms of the mean-squared-error of the * histogram, but in practice is comparable to the histograms produced by the R/S-Plus * statistical computing packages. Note: the output type of the 'x' field in the return value is * propagated from the input value consumed in the aggregate function. @@ -1386,10 +1409,10 @@ object functions { * Window function: returns the rank of rows within a window partition, without any gaps. * * The difference between rank and dense_rank is that denseRank leaves no gaps in ranking - * sequence when there are ties. 
That is, if you were ranking a competition using dense_rank - * and had three people tie for second place, you would say that all three were in second - * place and that the next person came in third. Rank would give me sequential numbers, making - * the person that came in third place (after the ties) would register as coming in fifth. + * sequence when there are ties. That is, if you were ranking a competition using dense_rank and + * had three people tie for second place, you would say that all three were in second place and + * that the next person came in third. Rank would give me sequential numbers, making the person + * that came in third place (after the ties) would register as coming in fifth. * * This is equivalent to the DENSE_RANK function in SQL. * @@ -1399,9 +1422,9 @@ object functions { def dense_rank(): Column = Column.fn("dense_rank") /** - * Window function: returns the value that is `offset` rows before the current row, and - * `null` if there is less than `offset` rows before the current row. For example, - * an `offset` of one will return the previous row at any given point in the window partition. + * Window function: returns the value that is `offset` rows before the current row, and `null` + * if there is less than `offset` rows before the current row. For example, an `offset` of one + * will return the previous row at any given point in the window partition. * * This is equivalent to the LAG function in SQL. * @@ -1411,9 +1434,9 @@ object functions { def lag(e: Column, offset: Int): Column = lag(e, offset, null) /** - * Window function: returns the value that is `offset` rows before the current row, and - * `null` if there is less than `offset` rows before the current row. For example, - * an `offset` of one will return the previous row at any given point in the window partition. + * Window function: returns the value that is `offset` rows before the current row, and `null` + * if there is less than `offset` rows before the current row. For example, an `offset` of one + * will return the previous row at any given point in the window partition. * * This is equivalent to the LAG function in SQL. * @@ -1424,8 +1447,8 @@ object functions { /** * Window function: returns the value that is `offset` rows before the current row, and - * `defaultValue` if there is less than `offset` rows before the current row. For example, - * an `offset` of one will return the previous row at any given point in the window partition. + * `defaultValue` if there is less than `offset` rows before the current row. For example, an + * `offset` of one will return the previous row at any given point in the window partition. * * This is equivalent to the LAG function in SQL. * @@ -1438,8 +1461,8 @@ object functions { /** * Window function: returns the value that is `offset` rows before the current row, and - * `defaultValue` if there is less than `offset` rows before the current row. For example, - * an `offset` of one will return the previous row at any given point in the window partition. + * `defaultValue` if there is less than `offset` rows before the current row. For example, an + * `offset` of one will return the previous row at any given point in the window partition. * * This is equivalent to the LAG function in SQL. * @@ -1453,9 +1476,9 @@ object functions { /** * Window function: returns the value that is `offset` rows before the current row, and * `defaultValue` if there is less than `offset` rows before the current row. 
`ignoreNulls` - * determines whether null values of row are included in or eliminated from the calculation. - * For example, an `offset` of one will return the previous row at any given point in the - * window partition. + * determines whether null values of row are included in or eliminated from the calculation. For + * example, an `offset` of one will return the previous row at any given point in the window + * partition. * * This is equivalent to the LAG function in SQL. * @@ -1466,9 +1489,9 @@ object functions { Column.fn("lag", false, e, lit(offset), lit(defaultValue), lit(ignoreNulls)) /** - * Window function: returns the value that is `offset` rows after the current row, and - * `null` if there is less than `offset` rows after the current row. For example, - * an `offset` of one will return the next row at any given point in the window partition. + * Window function: returns the value that is `offset` rows after the current row, and `null` if + * there is less than `offset` rows after the current row. For example, an `offset` of one will + * return the next row at any given point in the window partition. * * This is equivalent to the LEAD function in SQL. * @@ -1478,9 +1501,9 @@ object functions { def lead(columnName: String, offset: Int): Column = { lead(columnName, offset, null) } /** - * Window function: returns the value that is `offset` rows after the current row, and - * `null` if there is less than `offset` rows after the current row. For example, - * an `offset` of one will return the next row at any given point in the window partition. + * Window function: returns the value that is `offset` rows after the current row, and `null` if + * there is less than `offset` rows after the current row. For example, an `offset` of one will + * return the next row at any given point in the window partition. * * This is equivalent to the LEAD function in SQL. * @@ -1491,8 +1514,8 @@ object functions { /** * Window function: returns the value that is `offset` rows after the current row, and - * `defaultValue` if there is less than `offset` rows after the current row. For example, - * an `offset` of one will return the next row at any given point in the window partition. + * `defaultValue` if there is less than `offset` rows after the current row. For example, an + * `offset` of one will return the next row at any given point in the window partition. * * This is equivalent to the LEAD function in SQL. * @@ -1505,8 +1528,8 @@ object functions { /** * Window function: returns the value that is `offset` rows after the current row, and - * `defaultValue` if there is less than `offset` rows after the current row. For example, - * an `offset` of one will return the next row at any given point in the window partition. + * `defaultValue` if there is less than `offset` rows after the current row. For example, an + * `offset` of one will return the next row at any given point in the window partition. * * This is equivalent to the LEAD function in SQL. * @@ -1520,9 +1543,9 @@ object functions { /** * Window function: returns the value that is `offset` rows after the current row, and * `defaultValue` if there is less than `offset` rows after the current row. `ignoreNulls` - * determines whether null values of row are included in or eliminated from the calculation. - * The default value of `ignoreNulls` is false. For example, an `offset` of one will return - * the next row at any given point in the window partition. 
+ * determines whether null values of row are included in or eliminated from the calculation. The + * default value of `ignoreNulls` is false. For example, an `offset` of one will return the next + * row at any given point in the window partition. * * This is equivalent to the LEAD function in SQL. * @@ -1533,11 +1556,11 @@ object functions { Column.fn("lead", false, e, lit(offset), lit(defaultValue), lit(ignoreNulls)) /** - * Window function: returns the value that is the `offset`th row of the window frame - * (counting from 1), and `null` if the size of window frame is less than `offset` rows. + * Window function: returns the value that is the `offset`th row of the window frame (counting + * from 1), and `null` if the size of window frame is less than `offset` rows. * - * It will return the `offset`th non-null value it sees when ignoreNulls is set to true. - * If all values are null, then null is returned. + * It will return the `offset`th non-null value it sees when ignoreNulls is set to true. If all + * values are null, then null is returned. * * This is equivalent to the nth_value function in SQL. * @@ -1548,8 +1571,8 @@ object functions { Column.fn("nth_value", false, e, lit(offset), lit(ignoreNulls)) /** - * Window function: returns the value that is the `offset`th row of the window frame - * (counting from 1), and `null` if the size of window frame is less than `offset` rows. + * Window function: returns the value that is the `offset`th row of the window frame (counting + * from 1), and `null` if the size of window frame is less than `offset` rows. * * This is equivalent to the nth_value function in SQL. * @@ -1560,8 +1583,8 @@ object functions { /** * Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window - * partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the second - * quarter will get 2, the third quarter will get 3, and the last quarter will get 4. + * partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the + * second quarter will get 2, the third quarter will get 3, and the last quarter will get 4. * * This is equivalent to the NTILE function in SQL. * @@ -1571,7 +1594,8 @@ object functions { def ntile(n: Int): Column = Column.fn("ntile", lit(n)) /** - * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. + * Window function: returns the relative rank (i.e. percentile) of rows within a window + * partition. * * This is computed by: * {{{ @@ -1589,10 +1613,10 @@ object functions { * Window function: returns the rank of rows within a window partition. * * The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking - * sequence when there are ties. That is, if you were ranking a competition using dense_rank - * and had three people tie for second place, you would say that all three were in second - * place and that the next person came in third. Rank would give me sequential numbers, making - * the person that came in third place (after the ties) would register as coming in fifth. + * sequence when there are ties. That is, if you were ranking a competition using dense_rank and + * had three people tie for second place, you would say that all three were in second place and + * that the next person came in third. Rank would give me sequential numbers, making the person + * that came in third place (after the ties) would register as coming in fifth. * * This is equivalent to the RANK function in SQL. 
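// Illustrative sketch (not part of the diff): the ranking window functions documented above --
// rank leaves gaps after ties, dense_rank does not, ntile buckets the rows, and percent_rank
// gives the relative rank in [0, 1]. The DataFrame and column names are hypothetical.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, dense_rank, ntile, percent_rank, rank}

val w = Window.partitionBy("dept").orderBy(col("salary").desc)
employees
  .withColumn("rank", rank().over(w))
  .withColumn("dense_rank", dense_rank().over(w))
  .withColumn("quartile", ntile(4).over(w))
  .withColumn("pct_rank", percent_rank().over(w))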
* @@ -1630,13 +1654,13 @@ object functions { */ @scala.annotation.varargs def array(colName: String, colNames: String*): Column = { - array((colName +: colNames).map(col) : _*) + array((colName +: colNames).map(col): _*) } /** - * Creates a new map column. The input columns must be grouped as key-value pairs, e.g. - * (key1, value1, key2, value2, ...). The key columns must all have the same data type, and can't - * be null. The value columns must all have the same data type. + * Creates a new map column. The input columns must be grouped as key-value pairs, e.g. (key1, + * value1, key2, value2, ...). The key columns must all have the same data type, and can't be + * null. The value columns must all have the same data type. * * @group map_funcs * @since 2.0 @@ -1663,8 +1687,8 @@ object functions { Column.fn("map_from_arrays", keys, values) /** - * Creates a map after splitting the text into key/value pairs using delimiters. - * Both `pairDelim` and `keyValueDelim` are treated as regular expressions. + * Creates a map after splitting the text into key/value pairs using delimiters. Both + * `pairDelim` and `keyValueDelim` are treated as regular expressions. * * @group map_funcs * @since 3.5.0 @@ -1673,8 +1697,8 @@ object functions { Column.fn("str_to_map", text, pairDelim, keyValueDelim) /** - * Creates a map after splitting the text into key/value pairs using delimiters. - * The `pairDelim` is treated as regular expressions. + * Creates a map after splitting the text into key/value pairs using delimiters. The `pairDelim` + * is treated as regular expressions. * * @group map_funcs * @since 3.5.0 @@ -1702,15 +1726,15 @@ object functions { * @group normal_funcs * @since 1.5.0 */ - def broadcast[DS[U] <: api.Dataset[U, DS]](df: DS[_]): df.type = { + def broadcast[DS[U] <: api.Dataset[U]](df: DS[_]): df.type = { df.hint("broadcast").asInstanceOf[df.type] } /** * Returns the first column that is not null, or null if all inputs are null. * - * For example, `coalesce(a, b, c)` will return a if a is not null, - * or b if a is null and b is not null, or c if both a and b are null but c is not null. + * For example, `coalesce(a, b, c)` will return a if a is not null, or b if a is null and b is + * not null, or c if both a and b are null but c is not null. * * @group conditional_funcs * @since 1.3.0 @@ -1745,13 +1769,13 @@ object functions { /** * A column expression that generates monotonically increasing 64-bit integers. * - * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. - * The current implementation puts the partition ID in the upper 31 bits, and the record number - * within each partition in the lower 33 bits. The assumption is that the data frame has - * less than 1 billion partitions, and each partition has less than 8 billion records. + * The generated ID is guaranteed to be monotonically increasing and unique, but not + * consecutive. The current implementation puts the partition ID in the upper 31 bits, and the + * record number within each partition in the lower 33 bits. The assumption is that the data + * frame has less than 1 billion partitions, and each partition has less than 8 billion records. * - * As an example, consider a `DataFrame` with two partitions, each with 3 records. - * This expression would return the following IDs: + * As an example, consider a `DataFrame` with two partitions, each with 3 records. This + * expression would return the following IDs: * * {{{ * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. 
@@ -1766,13 +1790,13 @@ object functions { /** * A column expression that generates monotonically increasing 64-bit integers. * - * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. - * The current implementation puts the partition ID in the upper 31 bits, and the record number - * within each partition in the lower 33 bits. The assumption is that the data frame has - * less than 1 billion partitions, and each partition has less than 8 billion records. + * The generated ID is guaranteed to be monotonically increasing and unique, but not + * consecutive. The current implementation puts the partition ID in the upper 31 bits, and the + * record number within each partition in the lower 33 bits. The assumption is that the data + * frame has less than 1 billion partitions, and each partition has less than 8 billion records. * - * As an example, consider a `DataFrame` with two partitions, each with 3 records. - * This expression would return the following IDs: + * As an example, consider a `DataFrame` with two partitions, each with 3 records. This + * expression would return the following IDs: * * {{{ * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. @@ -1828,7 +1852,8 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * uniformly distributed in [0.0, 1.0). * - * @note The function is non-deterministic in general case. + * @note + * The function is non-deterministic in general case. * * @group math_funcs * @since 1.4.0 @@ -1839,7 +1864,8 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * uniformly distributed in [0.0, 1.0). * - * @note The function is non-deterministic in general case. + * @note + * The function is non-deterministic in general case. * * @group math_funcs * @since 1.4.0 @@ -1847,10 +1873,11 @@ object functions { def rand(): Column = rand(SparkClassUtils.random.nextLong) /** - * Generate a column with independent and identically distributed (i.i.d.) samples from - * the standard normal distribution. + * Generate a column with independent and identically distributed (i.i.d.) samples from the + * standard normal distribution. * - * @note The function is non-deterministic in general case. + * @note + * The function is non-deterministic in general case. * * @group math_funcs * @since 1.4.0 @@ -1858,10 +1885,11 @@ object functions { def randn(seed: Long): Column = Column.fn("randn", lit(seed)) /** - * Generate a column with independent and identically distributed (i.i.d.) samples from - * the standard normal distribution. + * Generate a column with independent and identically distributed (i.i.d.) samples from the + * standard normal distribution. * - * @note The function is non-deterministic in general case. + * @note + * The function is non-deterministic in general case. * * @group math_funcs * @since 1.4.0 @@ -1871,7 +1899,8 @@ object functions { /** * Partition ID. * - * @note This is non-deterministic because it depends on data partitioning and task scheduling. + * @note + * This is non-deterministic because it depends on data partitioning and task scheduling. * * @group misc_funcs * @since 1.6.0 @@ -1921,8 +1950,7 @@ object functions { def try_divide(left: Column, right: Column): Column = Column.fn("try_divide", left, right) /** - * Returns the remainder of `dividend``/``divisor`. Its result is - * always null if `divisor` is 0. + * Returns the remainder of `dividend``/``divisor`. 
Its result is always null if `divisor` is 0. * * @group math_funcs * @since 4.0.0 @@ -1956,11 +1984,10 @@ object functions { def try_sum(e: Column): Column = Column.fn("try_sum", e) /** - * Creates a new struct column. - * If the input column is a column in a `DataFrame`, or a derived column expression - * that is named (i.e. aliased), its name would be retained as the StructField's name, - * otherwise, the newly generated StructField's name would be auto generated as - * `col` with a suffix `index + 1`, i.e. col1, col2, col3, ... + * Creates a new struct column. If the input column is a column in a `DataFrame`, or a derived + * column expression that is named (i.e. aliased), its name would be retained as the + * StructField's name, otherwise, the newly generated StructField's name would be auto generated + * as `col` with a suffix `index + 1`, i.e. col1, col2, col3, ... * * @group struct_funcs * @since 1.4.0 @@ -1976,12 +2003,12 @@ object functions { */ @scala.annotation.varargs def struct(colName: String, colNames: String*): Column = { - struct((colName +: colNames).map(col) : _*) + struct((colName +: colNames).map(col): _*) } /** - * Evaluates a list of conditions and returns one of multiple possible result expressions. - * If otherwise is not defined at the end, null is returned for unmatched conditions. + * Evaluates a list of conditions and returns one of multiple possible result expressions. If + * otherwise is not defined at the end, null is returned for unmatched conditions. * * {{{ * // Example: encoding gender string column into integer. @@ -2030,9 +2057,8 @@ object functions { def bit_count(e: Column): Column = Column.fn("bit_count", e) /** - * Returns the value of the bit (0 or 1) at the specified position. - * The positions are numbered from right to left, starting at zero. - * The position argument cannot be negative. + * Returns the value of the bit (0 or 1) at the specified position. The positions are numbered + * from right to left, starting at zero. The position argument cannot be negative. * * @group bitwise_funcs * @since 3.5.0 @@ -2040,9 +2066,8 @@ object functions { def bit_get(e: Column, pos: Column): Column = Column.fn("bit_get", e, pos) /** - * Returns the value of the bit (0 or 1) at the specified position. - * The positions are numbered from right to left, starting at zero. - * The position argument cannot be negative. + * Returns the value of the bit (0 or 1) at the specified position. The positions are numbered + * from right to left, starting at zero. The position argument cannot be negative. 
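As a concrete, hypothetical illustration of the `struct`, `when`/`otherwise`, `bit_count` and `bit_get` semantics reworded above (spark-shell style; invented data, `spark.implicits._` assumed in scope):

import org.apache.spark.sql.functions._
import spark.implicits._

val people = Seq(("Alice", "F", 10L), ("Bob", "M", 10L), ("Eve", null, 10L)).toDF("name", "gender", "flags")

people.select(
  struct($"name", $"gender").as("person"),   // StructField names are taken from the input columns
  when($"gender" === "M", 0)
    .when($"gender" === "F", 1)
    .otherwise(2).as("gender_code"),         // the null gender falls through to otherwise
  bit_get($"flags", lit(1)).as("bit_1"),     // 10 = 0b1010, so bit 1 (counted from the right, starting at 0) is 1
  bit_count($"flags").as("set_bits")         // 2 bits are set in 10
).show()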
* * @group bitwise_funcs * @since 3.5.0 @@ -2074,7 +2099,8 @@ object functions { def abs(e: Column): Column = Column.fn("abs", e) /** - * @return inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos` + * @return + * inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos` * * @group math_funcs * @since 1.4.0 @@ -2082,7 +2108,8 @@ object functions { def acos(e: Column): Column = Column.fn("acos", e) /** - * @return inverse cosine of `columnName`, as if computed by `java.lang.Math.acos` + * @return + * inverse cosine of `columnName`, as if computed by `java.lang.Math.acos` * * @group math_funcs * @since 1.4.0 @@ -2090,7 +2117,8 @@ object functions { def acos(columnName: String): Column = acos(Column(columnName)) /** - * @return inverse hyperbolic cosine of `e` + * @return + * inverse hyperbolic cosine of `e` * * @group math_funcs * @since 3.1.0 @@ -2098,7 +2126,8 @@ object functions { def acosh(e: Column): Column = Column.fn("acosh", e) /** - * @return inverse hyperbolic cosine of `columnName` + * @return + * inverse hyperbolic cosine of `columnName` * * @group math_funcs * @since 3.1.0 @@ -2106,7 +2135,8 @@ object functions { def acosh(columnName: String): Column = acosh(Column(columnName)) /** - * @return inverse sine of `e` in radians, as if computed by `java.lang.Math.asin` + * @return + * inverse sine of `e` in radians, as if computed by `java.lang.Math.asin` * * @group math_funcs * @since 1.4.0 @@ -2114,7 +2144,8 @@ object functions { def asin(e: Column): Column = Column.fn("asin", e) /** - * @return inverse sine of `columnName`, as if computed by `java.lang.Math.asin` + * @return + * inverse sine of `columnName`, as if computed by `java.lang.Math.asin` * * @group math_funcs * @since 1.4.0 @@ -2122,7 +2153,8 @@ object functions { def asin(columnName: String): Column = asin(Column(columnName)) /** - * @return inverse hyperbolic sine of `e` + * @return + * inverse hyperbolic sine of `e` * * @group math_funcs * @since 3.1.0 @@ -2130,7 +2162,8 @@ object functions { def asinh(e: Column): Column = Column.fn("asinh", e) /** - * @return inverse hyperbolic sine of `columnName` + * @return + * inverse hyperbolic sine of `columnName` * * @group math_funcs * @since 3.1.0 @@ -2138,7 +2171,8 @@ object functions { def asinh(columnName: String): Column = asinh(Column(columnName)) /** - * @return inverse tangent of `e` as if computed by `java.lang.Math.atan` + * @return + * inverse tangent of `e` as if computed by `java.lang.Math.atan` * * @group math_funcs * @since 1.4.0 @@ -2146,7 +2180,8 @@ object functions { def atan(e: Column): Column = Column.fn("atan", e) /** - * @return inverse tangent of `columnName`, as if computed by `java.lang.Math.atan` + * @return + * inverse tangent of `columnName`, as if computed by `java.lang.Math.atan` * * @group math_funcs * @since 1.4.0 @@ -2154,13 +2189,14 @@ object functions { def atan(columnName: String): Column = atan(Column(columnName)) /** - * @param y coordinate on y-axis - * @param x coordinate on x-axis - * @return the theta component of the point - * (r, theta) - * in polar coordinates that corresponds to the point - * (x, y) in Cartesian coordinates, - * as if computed by `java.lang.Math.atan2` + * @param y + * coordinate on y-axis + * @param x + * coordinate on x-axis + * @return + * the theta component of the point (r, theta) in polar coordinates that + * corresponds to the point (x, y) in Cartesian coordinates, as if computed by + * `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 @@ -2168,13 +2204,14 @@ 
object functions { def atan2(y: Column, x: Column): Column = Column.fn("atan2", y, x) /** - * @param y coordinate on y-axis - * @param xName coordinate on x-axis - * @return the theta component of the point - * (r, theta) - * in polar coordinates that corresponds to the point - * (x, y) in Cartesian coordinates, - * as if computed by `java.lang.Math.atan2` + * @param y + * coordinate on y-axis + * @param xName + * coordinate on x-axis + * @return + * the theta component of the point (r, theta) in polar coordinates that + * corresponds to the point (x, y) in Cartesian coordinates, as if computed by + * `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 @@ -2182,13 +2219,14 @@ object functions { def atan2(y: Column, xName: String): Column = atan2(y, Column(xName)) /** - * @param yName coordinate on y-axis - * @param x coordinate on x-axis - * @return the theta component of the point - * (r, theta) - * in polar coordinates that corresponds to the point - * (x, y) in Cartesian coordinates, - * as if computed by `java.lang.Math.atan2` + * @param yName + * coordinate on y-axis + * @param x + * coordinate on x-axis + * @return + * the theta component of the point (r, theta) in polar coordinates that + * corresponds to the point (x, y) in Cartesian coordinates, as if computed by + * `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 @@ -2196,13 +2234,14 @@ object functions { def atan2(yName: String, x: Column): Column = atan2(Column(yName), x) /** - * @param yName coordinate on y-axis - * @param xName coordinate on x-axis - * @return the theta component of the point - * (r, theta) - * in polar coordinates that corresponds to the point - * (x, y) in Cartesian coordinates, - * as if computed by `java.lang.Math.atan2` + * @param yName + * coordinate on y-axis + * @param xName + * coordinate on x-axis + * @return + * the theta component of the point (r, theta) in polar coordinates that + * corresponds to the point (x, y) in Cartesian coordinates, as if computed by + * `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 @@ -2211,13 +2250,14 @@ object functions { atan2(Column(yName), Column(xName)) /** - * @param y coordinate on y-axis - * @param xValue coordinate on x-axis - * @return the theta component of the point - * (r, theta) - * in polar coordinates that corresponds to the point - * (x, y) in Cartesian coordinates, - * as if computed by `java.lang.Math.atan2` + * @param y + * coordinate on y-axis + * @param xValue + * coordinate on x-axis + * @return + * the theta component of the point (r, theta) in polar coordinates that + * corresponds to the point (x, y) in Cartesian coordinates, as if computed by + * `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 @@ -2225,13 +2265,14 @@ object functions { def atan2(y: Column, xValue: Double): Column = atan2(y, lit(xValue)) /** - * @param yName coordinate on y-axis - * @param xValue coordinate on x-axis - * @return the theta component of the point - * (r, theta) - * in polar coordinates that corresponds to the point - * (x, y) in Cartesian coordinates, - * as if computed by `java.lang.Math.atan2` + * @param yName + * coordinate on y-axis + * @param xValue + * coordinate on x-axis + * @return + * the theta component of the point (r, theta) in polar coordinates that + * corresponds to the point (x, y) in Cartesian coordinates, as if computed by + * `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 @@ -2239,13 +2280,14 @@ object functions { def atan2(yName: String, xValue: Double): Column = atan2(Column(yName), 
xValue) /** - * @param yValue coordinate on y-axis - * @param x coordinate on x-axis - * @return the theta component of the point - * (r, theta) - * in polar coordinates that corresponds to the point - * (x, y) in Cartesian coordinates, - * as if computed by `java.lang.Math.atan2` + * @param yValue + * coordinate on y-axis + * @param x + * coordinate on x-axis + * @return + * the theta component of the point (r, theta) in polar coordinates that + * corresponds to the point (x, y) in Cartesian coordinates, as if computed by + * `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 @@ -2253,13 +2295,14 @@ object functions { def atan2(yValue: Double, x: Column): Column = atan2(lit(yValue), x) /** - * @param yValue coordinate on y-axis - * @param xName coordinate on x-axis - * @return the theta component of the point - * (r, theta) - * in polar coordinates that corresponds to the point - * (x, y) in Cartesian coordinates, - * as if computed by `java.lang.Math.atan2` + * @param yValue + * coordinate on y-axis + * @param xName + * coordinate on x-axis + * @return + * the theta component of the point (r, theta) in polar coordinates that + * corresponds to the point (x, y) in Cartesian coordinates, as if computed by + * `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 @@ -2267,7 +2310,8 @@ object functions { def atan2(yValue: Double, xName: String): Column = atan2(yValue, Column(xName)) /** - * @return inverse hyperbolic tangent of `e` + * @return + * inverse hyperbolic tangent of `e` * * @group math_funcs * @since 3.1.0 @@ -2275,7 +2319,8 @@ object functions { def atanh(e: Column): Column = Column.fn("atanh", e) /** - * @return inverse hyperbolic tangent of `columnName` + * @return + * inverse hyperbolic tangent of `columnName` * * @group math_funcs * @since 3.1.0 @@ -2366,8 +2411,10 @@ object functions { Column.fn("conv", num, lit(fromBase), lit(toBase)) /** - * @param e angle in radians - * @return cosine of the angle, as if computed by `java.lang.Math.cos` + * @param e + * angle in radians + * @return + * cosine of the angle, as if computed by `java.lang.Math.cos` * * @group math_funcs * @since 1.4.0 @@ -2375,8 +2422,10 @@ object functions { def cos(e: Column): Column = Column.fn("cos", e) /** - * @param columnName angle in radians - * @return cosine of the angle, as if computed by `java.lang.Math.cos` + * @param columnName + * angle in radians + * @return + * cosine of the angle, as if computed by `java.lang.Math.cos` * * @group math_funcs * @since 1.4.0 @@ -2384,8 +2433,10 @@ object functions { def cos(columnName: String): Column = cos(Column(columnName)) /** - * @param e hyperbolic angle - * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh` + * @param e + * hyperbolic angle + * @return + * hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh` * * @group math_funcs * @since 1.4.0 @@ -2393,8 +2444,10 @@ object functions { def cosh(e: Column): Column = Column.fn("cosh", e) /** - * @param columnName hyperbolic angle - * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh` + * @param columnName + * hyperbolic angle + * @return + * hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh` * * @group math_funcs * @since 1.4.0 @@ -2402,8 +2455,10 @@ object functions { def cosh(columnName: String): Column = cosh(Column(columnName)) /** - * @param e angle in radians - * @return cotangent of the angle + * @param e + * angle in radians + * @return + * cotangent of the angle * * @group math_funcs * 
@since 3.3.0 @@ -2411,8 +2466,10 @@ object functions { def cot(e: Column): Column = Column.fn("cot", e) /** - * @param e angle in radians - * @return cosecant of the angle + * @param e + * angle in radians + * @return + * cosecant of the angle * * @group math_funcs * @since 3.3.0 @@ -2492,8 +2549,8 @@ object functions { def floor(columnName: String): Column = floor(Column(columnName)) /** - * Returns the greatest value of the list of values, skipping null values. - * This function takes at least 2 parameters. It will return null iff all parameters are null. + * Returns the greatest value of the list of values, skipping null values. This function takes + * at least 2 parameters. It will return null iff all parameters are null. * * @group math_funcs * @since 1.5.0 @@ -2502,8 +2559,8 @@ object functions { def greatest(exprs: Column*): Column = Column.fn("greatest", exprs: _*) /** - * Returns the greatest value of the list of column names, skipping null values. - * This function takes at least 2 parameters. It will return null iff all parameters are null. + * Returns the greatest value of the list of column names, skipping null values. This function + * takes at least 2 parameters. It will return null iff all parameters are null. * * @group math_funcs * @since 1.5.0 @@ -2522,8 +2579,8 @@ object functions { def hex(column: Column): Column = Column.fn("hex", column) /** - * Inverse of hex. Interprets each pair of characters as a hexadecimal number - * and converts to the byte representation of number. + * Inverse of hex. Interprets each pair of characters as a hexadecimal number and converts to + * the byte representation of number. * * @group math_funcs * @since 1.5.0 @@ -2596,8 +2653,8 @@ object functions { def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) /** - * Returns the least value of the list of values, skipping null values. - * This function takes at least 2 parameters. It will return null iff all parameters are null. + * Returns the least value of the list of values, skipping null values. This function takes at + * least 2 parameters. It will return null iff all parameters are null. * * @group math_funcs * @since 1.5.0 @@ -2606,8 +2663,8 @@ object functions { def least(exprs: Column*): Column = Column.fn("least", exprs: _*) /** - * Returns the least value of the list of column names, skipping null values. - * This function takes at least 2 parameters. It will return null iff all parameters are null. + * Returns the least value of the list of column names, skipping null values. This function + * takes at least 2 parameters. It will return null iff all parameters are null. * * @group math_funcs * @since 1.5.0 @@ -2810,8 +2867,8 @@ object functions { def pmod(dividend: Column, divisor: Column): Column = Column.fn("pmod", dividend, divisor) /** - * Returns the double value that is closest in value to the argument and - * is equal to a mathematical integer. + * Returns the double value that is closest in value to the argument and is equal to a + * mathematical integer. * * @group math_funcs * @since 1.4.0 @@ -2819,8 +2876,8 @@ object functions { def rint(e: Column): Column = Column.fn("rint", e) /** - * Returns the double value that is closest in value to the argument and - * is equal to a mathematical integer. + * Returns the double value that is closest in value to the argument and is equal to a + * mathematical integer. 
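To make the null-skipping of `greatest`/`least` documented above concrete, together with `pmod` and `rint` from the same hunk, a small hypothetical sketch (spark-shell style, invented values):

import org.apache.spark.sql.functions._
import spark.implicits._

val nums = Seq((1, 5, None: Option[Int]), (-7, 3, Some(10))).toDF("a", "b", "c")

nums.select(
  greatest($"a", $"b", $"c").as("greatest"),   // nulls are skipped: 5, then 10
  least($"a", $"b", $"c").as("least"),         // 1, then -7
  pmod($"a", lit(3)).as("pmod_a_3"),           // pmod(-7, 3) = 2, whereas -7 % 3 = -1
  rint(lit(2.5)).as("rint_2_5")                // 2.0: the closest mathematical integer (half values go to the even one)
).show()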
* * @group math_funcs * @since 1.4.0 @@ -2836,8 +2893,8 @@ object functions { def round(e: Column): Column = round(e, 0) /** - * Round the value of `e` to `scale` decimal places with HALF_UP round mode - * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. + * Round the value of `e` to `scale` decimal places with HALF_UP round mode if `scale` is + * greater than or equal to 0 or at integral part when `scale` is less than 0. * * @group math_funcs * @since 1.5.0 @@ -2845,8 +2902,8 @@ object functions { def round(e: Column, scale: Int): Column = Column.fn("round", e, lit(scale)) /** - * Round the value of `e` to `scale` decimal places with HALF_UP round mode - * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. + * Round the value of `e` to `scale` decimal places with HALF_UP round mode if `scale` is + * greater than or equal to 0 or at integral part when `scale` is less than 0. * * @group math_funcs * @since 4.0.0 @@ -2862,8 +2919,8 @@ object functions { def bround(e: Column): Column = bround(e, 0) /** - * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode - * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. + * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode if `scale` is + * greater than or equal to 0 or at integral part when `scale` is less than 0. * * @group math_funcs * @since 2.0.0 @@ -2871,8 +2928,8 @@ object functions { def bround(e: Column, scale: Int): Column = Column.fn("bround", e, lit(scale)) /** - * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode - * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. + * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode if `scale` is + * greater than or equal to 0 or at integral part when `scale` is less than 0. * * @group math_funcs * @since 4.0.0 @@ -2880,8 +2937,10 @@ object functions { def bround(e: Column, scale: Column): Column = Column.fn("bround", e, scale) /** - * @param e angle in radians - * @return secant of the angle + * @param e + * angle in radians + * @return + * secant of the angle * * @group math_funcs * @since 3.3.0 @@ -2889,8 +2948,8 @@ object functions { def sec(e: Column): Column = Column.fn("sec", e) /** - * Shift the given value numBits left. If the given value is a long value, this function - * will return a long value else it will return an integer value. + * Shift the given value numBits left. If the given value is a long value, this function will + * return a long value else it will return an integer value. * * @group bitwise_funcs * @since 1.5.0 @@ -2899,8 +2958,8 @@ object functions { def shiftLeft(e: Column, numBits: Int): Column = shiftleft(e, numBits) /** - * Shift the given value numBits left. If the given value is a long value, this function - * will return a long value else it will return an integer value. + * Shift the given value numBits left. If the given value is a long value, this function will + * return a long value else it will return an integer value. * * @group bitwise_funcs * @since 3.2.0 @@ -2927,8 +2986,8 @@ object functions { def shiftright(e: Column, numBits: Int): Column = Column.fn("shiftright", e, lit(numBits)) /** - * Unsigned shift the given value numBits right. If the given value is a long value, - * it will return a long value else it will return an integer value. + * Unsigned shift the given value numBits right. 
If the given value is a long value, it will + * return a long value else it will return an integer value. * * @group bitwise_funcs * @since 1.5.0 @@ -2937,8 +2996,8 @@ object functions { def shiftRightUnsigned(e: Column, numBits: Int): Column = shiftrightunsigned(e, numBits) /** - * Unsigned shift the given value numBits right. If the given value is a long value, - * it will return a long value else it will return an integer value. + * Unsigned shift the given value numBits right. If the given value is a long value, it will + * return a long value else it will return an integer value. * * @group bitwise_funcs * @since 3.2.0 @@ -2971,8 +3030,10 @@ object functions { def signum(columnName: String): Column = signum(Column(columnName)) /** - * @param e angle in radians - * @return sine of the angle, as if computed by `java.lang.Math.sin` + * @param e + * angle in radians + * @return + * sine of the angle, as if computed by `java.lang.Math.sin` * * @group math_funcs * @since 1.4.0 @@ -2980,8 +3041,10 @@ object functions { def sin(e: Column): Column = Column.fn("sin", e) /** - * @param columnName angle in radians - * @return sine of the angle, as if computed by `java.lang.Math.sin` + * @param columnName + * angle in radians + * @return + * sine of the angle, as if computed by `java.lang.Math.sin` * * @group math_funcs * @since 1.4.0 @@ -2989,8 +3052,10 @@ object functions { def sin(columnName: String): Column = sin(Column(columnName)) /** - * @param e hyperbolic angle - * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh` + * @param e + * hyperbolic angle + * @return + * hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh` * * @group math_funcs * @since 1.4.0 @@ -2998,8 +3063,10 @@ object functions { def sinh(e: Column): Column = Column.fn("sinh", e) /** - * @param columnName hyperbolic angle - * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh` + * @param columnName + * hyperbolic angle + * @return + * hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh` * * @group math_funcs * @since 1.4.0 @@ -3007,8 +3074,10 @@ object functions { def sinh(columnName: String): Column = sinh(Column(columnName)) /** - * @param e angle in radians - * @return tangent of the given value, as if computed by `java.lang.Math.tan` + * @param e + * angle in radians + * @return + * tangent of the given value, as if computed by `java.lang.Math.tan` * * @group math_funcs * @since 1.4.0 @@ -3016,8 +3085,10 @@ object functions { def tan(e: Column): Column = Column.fn("tan", e) /** - * @param columnName angle in radians - * @return tangent of the given value, as if computed by `java.lang.Math.tan` + * @param columnName + * angle in radians + * @return + * tangent of the given value, as if computed by `java.lang.Math.tan` * * @group math_funcs * @since 1.4.0 @@ -3025,8 +3096,10 @@ object functions { def tan(columnName: String): Column = tan(Column(columnName)) /** - * @param e hyperbolic angle - * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh` + * @param e + * hyperbolic angle + * @return + * hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh` * * @group math_funcs * @since 1.4.0 @@ -3034,8 +3107,10 @@ object functions { def tanh(e: Column): Column = Column.fn("tanh", e) /** - * @param columnName hyperbolic angle - * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh` + * @param columnName + * hyperbolic angle + * 
@return + * hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh` * * @group math_funcs * @since 1.4.0 @@ -3057,10 +3132,13 @@ object functions { def toDegrees(columnName: String): Column = degrees(Column(columnName)) /** - * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * Converts an angle measured in radians to an approximately equivalent angle measured in + * degrees. * - * @param e angle in radians - * @return angle in degrees, as if computed by `java.lang.Math.toDegrees` + * @param e + * angle in radians + * @return + * angle in degrees, as if computed by `java.lang.Math.toDegrees` * * @group math_funcs * @since 2.1.0 @@ -3068,10 +3146,13 @@ object functions { def degrees(e: Column): Column = Column.fn("degrees", e) /** - * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * Converts an angle measured in radians to an approximately equivalent angle measured in + * degrees. * - * @param columnName angle in radians - * @return angle in degrees, as if computed by `java.lang.Math.toDegrees` + * @param columnName + * angle in radians + * @return + * angle in degrees, as if computed by `java.lang.Math.toDegrees` * * @group math_funcs * @since 2.1.0 @@ -3093,10 +3174,13 @@ object functions { def toRadians(columnName: String): Column = radians(Column(columnName)) /** - * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + * Converts an angle measured in degrees to an approximately equivalent angle measured in + * radians. * - * @param e angle in degrees - * @return angle in radians, as if computed by `java.lang.Math.toRadians` + * @param e + * angle in degrees + * @return + * angle in radians, as if computed by `java.lang.Math.toRadians` * * @group math_funcs * @since 2.1.0 @@ -3104,10 +3188,13 @@ object functions { def radians(e: Column): Column = Column.fn("radians", e) /** - * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + * Converts an angle measured in degrees to an approximately equivalent angle measured in + * radians. * - * @param columnName angle in degrees - * @return angle in radians, as if computed by `java.lang.Math.toRadians` + * @param columnName + * angle in degrees + * @return + * angle in radians, as if computed by `java.lang.Math.toRadians` * * @group math_funcs * @since 2.1.0 @@ -3115,15 +3202,20 @@ object functions { def radians(columnName: String): Column = radians(Column(columnName)) /** - * Returns the bucket number into which the value of this expression would fall - * after being evaluated. Note that input arguments must follow conditions listed below; - * otherwise, the method will return null. + * Returns the bucket number into which the value of this expression would fall after being + * evaluated. Note that input arguments must follow conditions listed below; otherwise, the + * method will return null. 
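The radians/degrees conversions reworded above, plus `width_bucket` (whose parameters are listed just below), can be exercised with a short hypothetical snippet (spark-shell style, invented values):

import org.apache.spark.sql.functions._
import spark.implicits._

val angles = Seq(10.0, 45.0, 90.0, 170.0).toDF("deg")

angles.select(
  $"deg",
  radians($"deg").as("rad"),                                      // degrees -> radians
  degrees(radians($"deg")).as("roundtrip"),                       // back to (approximately) the input
  width_bucket($"deg", lit(0.0), lit(180.0), lit(4)).as("bucket") // 1, 2, 3, 4 for these values over [0, 180)
).show()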
* - * @param v value to compute a bucket number in the histogram - * @param min minimum value of the histogram - * @param max maximum value of the histogram - * @param numBucket the number of buckets - * @return the bucket number into which the value would fall after being evaluated + * @param v + * value to compute a bucket number in the histogram + * @param min + * minimum value of the histogram + * @param max + * maximum value of the histogram + * @param numBucket + * the number of buckets + * @return + * the bucket number into which the value would fall after being evaluated * @group math_funcs * @since 3.5.0 */ @@ -3167,8 +3259,8 @@ object functions { def current_user(): Column = Column.fn("current_user") /** - * Calculates the MD5 digest of a binary column and returns the value - * as a 32 character hex string. + * Calculates the MD5 digest of a binary column and returns the value as a 32 character hex + * string. * * @group hash_funcs * @since 1.5.0 @@ -3176,8 +3268,8 @@ object functions { def md5(e: Column): Column = Column.fn("md5", e) /** - * Calculates the SHA-1 digest of a binary column and returns the value - * as a 40 character hex string. + * Calculates the SHA-1 digest of a binary column and returns the value as a 40 character hex + * string. * * @group hash_funcs * @since 1.5.0 @@ -3185,11 +3277,13 @@ object functions { def sha1(e: Column): Column = Column.fn("sha1", e) /** - * Calculates the SHA-2 family of hash functions of a binary column and - * returns the value as a hex string. + * Calculates the SHA-2 family of hash functions of a binary column and returns the value as a + * hex string. * - * @param e column to compute SHA-2 on. - * @param numBits one of 224, 256, 384, or 512. + * @param e + * column to compute SHA-2 on. + * @param numBits + * one of 224, 256, 384, or 512. * * @group hash_funcs * @since 1.5.0 @@ -3202,8 +3296,8 @@ object functions { } /** - * Calculates the cyclic redundancy check value (CRC32) of a binary column and - * returns the value as a bigint. + * Calculates the cyclic redundancy check value (CRC32) of a binary column and returns the value + * as a bigint. * * @group hash_funcs * @since 1.5.0 @@ -3220,9 +3314,8 @@ object functions { def hash(cols: Column*): Column = Column.fn("hash", cols: _*) /** - * Calculates the hash code of given columns using the 64-bit - * variant of the xxHash algorithm, and returns the result as a long - * column. The hash computation uses an initial seed of 42. + * Calculates the hash code of given columns using the 64-bit variant of the xxHash algorithm, + * and returns the result as a long column. The hash computation uses an initial seed of 42. * * @group hash_funcs * @since 3.0.0 @@ -3255,8 +3348,8 @@ object functions { def raise_error(c: Column): Column = Column.fn("raise_error", c) /** - * Returns the estimated number of unique values given the binary representation - * of a Datasketches HllSketch. + * Returns the estimated number of unique values given the binary representation of a + * Datasketches HllSketch. * * @group misc_funcs * @since 3.5.0 @@ -3264,8 +3357,8 @@ object functions { def hll_sketch_estimate(c: Column): Column = Column.fn("hll_sketch_estimate", c) /** - * Returns the estimated number of unique values given the binary representation - * of a Datasketches HllSketch. + * Returns the estimated number of unique values given the binary representation of a + * Datasketches HllSketch. 
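The digest functions documented above (md5, sha1, sha2, crc32, hash, xxhash64) are deterministic column transforms; a brief hypothetical example (spark-shell style, invented data):

import org.apache.spark.sql.functions._
import spark.implicits._

val words = Seq("Spark", "SQL").toDF("word")

words.select(
  $"word",
  md5($"word".cast("binary")).as("md5_hex"),       // 32-character hex string
  sha2($"word".cast("binary"), 256).as("sha256"),  // numBits must be 224, 256, 384 or 512
  crc32($"word".cast("binary")).as("crc32"),       // bigint
  xxhash64($"word").as("xx64"),                    // 64-bit xxHash, initial seed 42
  hash($"word").as("murmur")                       // Murmur3-based hash
).show(truncate = false)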
* * @group misc_funcs * @since 3.5.0 @@ -3275,9 +3368,8 @@ object functions { } /** - * Merges two binary representations of Datasketches HllSketch objects, using a - * Datasketches Union object. Throws an exception if sketches have different - * lgConfigK values. + * Merges two binary representations of Datasketches HllSketch objects, using a Datasketches + * Union object. Throws an exception if sketches have different lgConfigK values. * * @group misc_funcs * @since 3.5.0 @@ -3286,9 +3378,8 @@ object functions { Column.fn("hll_union", c1, c2) /** - * Merges two binary representations of Datasketches HllSketch objects, using a - * Datasketches Union object. Throws an exception if sketches have different - * lgConfigK values. + * Merges two binary representations of Datasketches HllSketch objects, using a Datasketches + * Union object. Throws an exception if sketches have different lgConfigK values. * * @group misc_funcs * @since 3.5.0 @@ -3298,9 +3389,9 @@ object functions { } /** - * Merges two binary representations of Datasketches HllSketch objects, using a - * Datasketches Union object. Throws an exception if sketches have different - * lgConfigK values and allowDifferentLgConfigK is set to false. + * Merges two binary representations of Datasketches HllSketch objects, using a Datasketches + * Union object. Throws an exception if sketches have different lgConfigK values and + * allowDifferentLgConfigK is set to false. * * @group misc_funcs * @since 3.5.0 @@ -3309,15 +3400,17 @@ object functions { Column.fn("hll_union", c1, c2, lit(allowDifferentLgConfigK)) /** - * Merges two binary representations of Datasketches HllSketch objects, using a - * Datasketches Union object. Throws an exception if sketches have different - * lgConfigK values and allowDifferentLgConfigK is set to false. + * Merges two binary representations of Datasketches HllSketch objects, using a Datasketches + * Union object. Throws an exception if sketches have different lgConfigK values and + * allowDifferentLgConfigK is set to false. * * @group misc_funcs * @since 3.5.0 */ - def hll_union(columnName1: String, columnName2: String, allowDifferentLgConfigK: Boolean): - Column = { + def hll_union( + columnName1: String, + columnName2: String, + allowDifferentLgConfigK: Boolean): Column = { hll_union(Column(columnName1), Column(columnName2), allowDifferentLgConfigK) } @@ -3684,8 +3777,8 @@ object functions { Column.fn("bitmap_bucket_number", col) /** - * Returns a bitmap with the positions of the bits set from all the values from the input column. - * The input column will most likely be bitmap_bit_position(). + * Returns a bitmap with the positions of the bits set from all the values from the input + * column. The input column will most likely be bitmap_bit_position(). * * @group agg_funcs * @since 3.5.0 @@ -3702,8 +3795,8 @@ object functions { def bitmap_count(col: Column): Column = Column.fn("bitmap_count", col) /** - * Returns a bitmap that is the bitwise OR of all of the bitmaps from the input column. - * The input column should be bitmaps created from bitmap_construct_agg(). + * Returns a bitmap that is the bitwise OR of all of the bitmaps from the input column. The + * input column should be bitmaps created from bitmap_construct_agg(). * * @group agg_funcs * @since 3.5.0 @@ -3724,8 +3817,8 @@ object functions { def ascii(e: Column): Column = Column.fn("ascii", e) /** - * Computes the BASE64 encoding of a binary column and returns it as a string column. - * This is the reverse of unbase64. 
+ * Computes the BASE64 encoding of a binary column and returns it as a string column. This is + * the reverse of unbase64. * * @group string_funcs * @since 1.5.0 @@ -3741,10 +3834,11 @@ object functions { def bit_length(e: Column): Column = Column.fn("bit_length", e) /** - * Concatenates multiple input string columns together into a single string column, - * using the given separator. + * Concatenates multiple input string columns together into a single string column, using the + * given separator. * - * @note Input strings which are null are skipped. + * @note + * Input strings which are null are skipped. * * @group string_funcs * @since 1.5.0 @@ -3754,9 +3848,9 @@ object functions { Column.fn("concat_ws", lit(sep) +: exprs: _*) /** - * Computes the first argument into a string from a binary using the provided character set - * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16', 'UTF-32'). - * If either argument is null, the result will also be null. + * Computes the first argument into a string from a binary using the provided character set (one + * of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16', 'UTF-32'). If either + * argument is null, the result will also be null. * * @group string_funcs * @since 1.5.0 @@ -3765,9 +3859,9 @@ object functions { Column.fn("decode", value, lit(charset)) /** - * Computes the first argument into a binary from a string using the provided character set - * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16', 'UTF-32'). - * If either argument is null, the result will also be null. + * Computes the first argument into a binary from a string using the provided character set (one + * of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16', 'UTF-32'). If either + * argument is null, the result will also be null. * * @group string_funcs * @since 1.5.0 @@ -3776,11 +3870,11 @@ object functions { Column.fn("encode", value, lit(charset)) /** - * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places - * with HALF_EVEN round mode, and returns the result as a string column. + * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places with + * HALF_EVEN round mode, and returns the result as a string column. * - * If d is 0, the result has no decimal point or fractional part. - * If d is less than 0, the result will be null. + * If d is 0, the result has no decimal point or fractional part. If d is less than 0, the + * result will be null. * * @group string_funcs * @since 1.5.0 @@ -3798,8 +3892,8 @@ object functions { Column.fn("format_string", lit(format) +: arguments: _*) /** - * Returns a new string column by converting the first letter of each word to uppercase. - * Words are delimited by whitespace. + * Returns a new string column by converting the first letter of each word to uppercase. Words + * are delimited by whitespace. * * For example, "hello world" will become "Hello World". * @@ -3809,11 +3903,12 @@ object functions { def initcap(e: Column): Column = Column.fn("initcap", e) /** - * Locate the position of the first occurrence of substr column in the given string. - * Returns null if either of the arguments are null. + * Locate the position of the first occurrence of substr column in the given string. Returns + * null if either of the arguments are null. * - * @note The position is not zero based, but 1 based index. Returns 0 if substr - * could not be found in str. 
+ * @note + * The position is not zero based, but 1 based index. Returns 0 if substr could not be found + * in str. * * @group string_funcs * @since 1.5.0 @@ -3821,8 +3916,8 @@ object functions { def instr(str: Column, substring: String): Column = Column.fn("instr", str, lit(substring)) /** - * Computes the character length of a given string or number of bytes of a binary string. - * The length of character strings include the trailing spaces. The length of binary strings + * Computes the character length of a given string or number of bytes of a binary string. The + * length of character strings include the trailing spaces. The length of binary strings * includes binary zeros. * * @group string_funcs @@ -3831,8 +3926,8 @@ object functions { def length(e: Column): Column = Column.fn("length", e) /** - * Computes the character length of a given string or number of bytes of a binary string. - * The length of character strings include the trailing spaces. The length of binary strings + * Computes the character length of a given string or number of bytes of a binary string. The + * length of character strings include the trailing spaces. The length of binary strings * includes binary zeros. * * @group string_funcs @@ -3849,9 +3944,10 @@ object functions { def lower(e: Column): Column = Column.fn("lower", e) /** - * Computes the Levenshtein distance of the two given string columns if it's less than or - * equal to a given threshold. - * @return result distance, or -1 + * Computes the Levenshtein distance of the two given string columns if it's less than or equal + * to a given threshold. + * @return + * result distance, or -1 * @group string_funcs * @since 3.5.0 */ @@ -3868,8 +3964,9 @@ object functions { /** * Locate the position of the first occurrence of substr. * - * @note The position is not zero based, but 1 based index. Returns 0 if substr - * could not be found in str. + * @note + * The position is not zero based, but 1 based index. Returns 0 if substr could not be found + * in str. * * @group string_funcs * @since 1.5.0 @@ -3879,8 +3976,9 @@ object functions { /** * Locate the position of the first occurrence of substr in a string column, after position pos. * - * @note The position is not zero based, but 1 based index. returns 0 if substr - * could not be found in str. + * @note + * The position is not zero based, but 1 based index. returns 0 if substr could not be found + * in str. * * @group string_funcs * @since 1.5.0 @@ -3889,8 +3987,8 @@ object functions { Column.fn("locate", lit(substr), str, lit(pos)) /** - * Left-pad the string column with pad to a length of len. If the string column is longer - * than len, the return value is shortened to len characters. + * Left-pad the string column with pad to a length of len. If the string column is longer than + * len, the return value is shortened to len characters. * * @group string_funcs * @since 1.5.0 @@ -3972,8 +4070,8 @@ object functions { def regexp_like(str: Column, regexp: Column): Column = Column.fn("regexp_like", str, regexp) /** - * Returns a count of the number of times that the regular expression pattern `regexp` - * is matched in the string `str`. + * Returns a count of the number of times that the regular expression pattern `regexp` is + * matched in the string `str`. 
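A compact, hypothetical example of the string helpers reworded above (1-based positions for instr/locate, threshold-aware levenshtein, and regexp_count); spark-shell style, invented data:

import org.apache.spark.sql.functions._
import spark.implicits._

val strings = Seq("Spark SQL").toDF("s")

strings.select(
  instr($"s", "SQL").as("instr"),                      // 7: 1-based, 0 if not found
  locate("a", $"s", 3).as("locate_from_3"),            // search starts at position 3 (1-based)
  lpad($"s", 12, "*").as("lpad"),                      // "***Spark SQL"
  levenshtein($"s", lit("Spark"), 3).as("lev_le_3"),   // -1: the distance (4) exceeds the threshold
  regexp_count($"s", lit("S[a-zA-Z]*")).as("s_words")  // 2: "Spark" and "SQL"
).show(truncate = false)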
* * @group string_funcs * @since 3.5.0 @@ -3981,10 +4079,10 @@ object functions { def regexp_count(str: Column, regexp: Column): Column = Column.fn("regexp_count", str, regexp) /** - * Extract a specific group matched by a Java regex, from the specified string column. - * If the regex did not match, or the specified group did not match, an empty string is returned. - * if the specified group index exceeds the group count of regex, an IllegalArgumentException - * will be thrown. + * Extract a specific group matched by a Java regex, from the specified string column. If the + * regex did not match, or the specified group did not match, an empty string is returned. if + * the specified group index exceeds the group count of regex, an IllegalArgumentException will + * be thrown. * * @group string_funcs * @since 1.5.0 @@ -3993,8 +4091,8 @@ object functions { Column.fn("regexp_extract", e, lit(exp), lit(groupIdx)) /** - * Extract all strings in the `str` that match the `regexp` expression and - * corresponding to the first regex group index. + * Extract all strings in the `str` that match the `regexp` expression and corresponding to the + * first regex group index. * * @group string_funcs * @since 3.5.0 @@ -4003,8 +4101,8 @@ object functions { Column.fn("regexp_extract_all", str, regexp) /** - * Extract all strings in the `str` that match the `regexp` expression and - * corresponding to the regex group index. + * Extract all strings in the `str` that match the `regexp` expression and corresponding to the + * regex group index. * * @group string_funcs * @since 3.5.0 @@ -4040,9 +4138,9 @@ object functions { def regexp_substr(str: Column, regexp: Column): Column = Column.fn("regexp_substr", str, regexp) /** - * Searches a string for a regular expression and returns an integer that indicates - * the beginning position of the matched substring. Positions are 1-based, not 0-based. - * If no match is found, returns 0. + * Searches a string for a regular expression and returns an integer that indicates the + * beginning position of the matched substring. Positions are 1-based, not 0-based. If no match + * is found, returns 0. * * @group string_funcs * @since 3.5.0 @@ -4050,9 +4148,9 @@ object functions { def regexp_instr(str: Column, regexp: Column): Column = Column.fn("regexp_instr", str, regexp) /** - * Searches a string for a regular expression and returns an integer that indicates - * the beginning position of the matched substring. Positions are 1-based, not 0-based. - * If no match is found, returns 0. + * Searches a string for a regular expression and returns an integer that indicates the + * beginning position of the matched substring. Positions are 1-based, not 0-based. If no match + * is found, returns 0. * * @group string_funcs * @since 3.5.0 @@ -4061,8 +4159,8 @@ object functions { Column.fn("regexp_instr", str, regexp, idx) /** - * Decodes a BASE64 encoded string column and returns it as a binary column. - * This is the reverse of base64. + * Decodes a BASE64 encoded string column and returns it as a binary column. This is the reverse + * of base64. * * @group string_funcs * @since 1.5.0 @@ -4070,8 +4168,8 @@ object functions { def unbase64(e: Column): Column = Column.fn("unbase64", e) /** - * Right-pad the string column with pad to a length of len. If the string column is longer - * than len, the return value is shortened to len characters. + * Right-pad the string column with pad to a length of len. 
If the string column is longer than + len, the return value is shortened to len characters. * * @group string_funcs * @since 1.5.0 @@ -4199,11 +4297,11 @@ object functions { Column.fn("split", str, pattern, limit) /** - * Substring starts at `pos` and is of length `len` when str is String type or - * returns the slice of byte array that starts at `pos` in byte and is of length `len` - * when str is Binary type + * Substring starts at `pos` and is of length `len` when str is String type or returns the slice + * of byte array that starts at `pos` in byte and is of length `len` when str is Binary type * - * @note The position is not zero based, but 1 based index. + * @note + * The position is not zero based, but 1 based index. * * @group string_funcs * @since 1.5.0 @@ -4212,11 +4310,11 @@ object functions { Column.fn("substring", str, lit(pos), lit(len)) /** - * Substring starts at `pos` and is of length `len` when str is String type or - * returns the slice of byte array that starts at `pos` in byte and is of length `len` - * when str is Binary type + * Substring starts at `pos` and is of length `len` when str is String type or returns the slice + * of byte array that starts at `pos` in byte and is of length `len` when str is Binary type * - * @note The position is not zero based, but 1 based index. + * @note + * The position is not zero based, but 1 based index. * * @group string_funcs * @since 4.0.0 @@ -4225,8 +4323,8 @@ object functions { Column.fn("substring", str, pos, len) /** - * Returns the substring from string str before count occurrences of the delimiter delim. - * If count is positive, everything the left of the final delimiter (counting from left) is + * Returns the substring from string str before count occurrences of the delimiter delim. If + * count is positive, everything to the left of the final delimiter (counting from left) is * returned. If count is negative, everything to the right of the final delimiter (counting from the * right) is returned. substring_index performs a case-sensitive match when searching for delim. * * @@ -4236,8 +4334,8 @@ object functions { Column.fn("substring_index", str, lit(delim), lit(count)) /** - * Overlay the specified portion of `src` with `replace`, - * starting from byte position `pos` of `src` and proceeding for `len` bytes. + * Overlay the specified portion of `src` with `replace`, starting from byte position `pos` of + * `src` and proceeding for `len` bytes. * * @group string_funcs * @since 3.0.0 @@ -4246,8 +4344,8 @@ object functions { Column.fn("overlay", src, replace, pos, len) /** - * Overlay the specified portion of `src` with `replace`, - * starting from byte position `pos` of `src`. + * Overlay the specified portion of `src` with `replace`, starting from byte position `pos` of + * `src`. * * @group string_funcs * @since 3.0.0 @@ -4264,18 +4362,26 @@ object functions { Column.fn("sentences", string, language, country) /** - * Splits a string into arrays of sentences, where each sentence is an array of words. - * The default locale is used. + * Splits a string into arrays of sentences, where each sentence is an array of words. The + * default `country`('') is used. + * @group string_funcs + * @since 4.0.0 + */ + def sentences(string: Column, language: Column): Column = + Column.fn("sentences", string, language) + + /** + * Splits a string into arrays of sentences, where each sentence is an array of words. The + * default locale is used.
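The substring, substring_index, overlay and sentences descriptions above are easiest to verify on a tiny example; a hypothetical sketch (spark-shell style, invented data):

import org.apache.spark.sql.functions._
import spark.implicits._

val urls = Seq("www.apache.org").toDF("s")

urls.select(
  substring($"s", 5, 6).as("sub"),                           // "apache" (start position is 1-based)
  substring_index($"s", ".", 2).as("before_second_dot"),     // "www.apache"
  substring_index($"s", ".", -1).as("after_last_dot"),       // "org"
  overlay($"s", lit("spark!"), lit(5)).as("overlaid"),       // "www.spark!.org"
  sentences(lit("Hi there. How are you?")).as("sentences")   // [[Hi, there], [How, are, you]], default locale
).show(truncate = false)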
* @group string_funcs * @since 3.2.0 */ def sentences(string: Column): Column = Column.fn("sentences", string) /** - * Translate any character in the src by a character in replaceString. - * The characters in replaceString correspond to the characters in matchingString. - * The translate will happen when any character in the string matches the character - * in the `matchingString`. + * Translate any character in the src by a character in replaceString. The characters in + * replaceString correspond to the characters in matchingString. The translate will happen when + * any character in the string matches the character in the `matchingString`. * * @group string_funcs * @since 1.5.0 @@ -4307,10 +4413,10 @@ object functions { def upper(e: Column): Column = Column.fn("upper", e) /** - * Converts the input `e` to a binary value based on the supplied `format`. - * The `format` can be a case-insensitive string literal of "hex", "utf-8", "utf8", or "base64". - * By default, the binary format for conversion is "hex" if `format` is omitted. - * The function returns NULL if at least one of the input parameters is NULL. + * Converts the input `e` to a binary value based on the supplied `format`. The `format` can be + * a case-insensitive string literal of "hex", "utf-8", "utf8", or "base64". By default, the + * binary format for conversion is "hex" if `format` is omitted. The function returns NULL if at + * least one of the input parameters is NULL. * * @group string_funcs * @since 3.5.0 @@ -4318,8 +4424,8 @@ object functions { def to_binary(e: Column, f: Column): Column = Column.fn("to_binary", e, f) /** - * Converts the input `e` to a binary value based on the default format "hex". - * The function returns NULL if at least one of the input parameters is NULL. + * Converts the input `e` to a binary value based on the default format "hex". The function + * returns NULL if at least one of the input parameters is NULL. * * @group string_funcs * @since 3.5.0 @@ -4328,32 +4434,27 @@ object functions { // scalastyle:off line.size.limit /** - * Convert `e` to a string based on the `format`. - * Throws an exception if the conversion fails. The format can consist of the following - * characters, case insensitive: - * '0' or '9': Specifies an expected digit between 0 and 9. A sequence of 0 or 9 in the format - * string matches a sequence of digits in the input value, generating a result string of the - * same length as the corresponding sequence in the format string. The result string is - * left-padded with zeros if the 0/9 sequence comprises more digits than the matching part of - * the decimal value, starts with 0, and is before the decimal point. Otherwise, it is - * padded with spaces. - * '.' or 'D': Specifies the position of the decimal point (optional, only allowed once). - * ',' or 'G': Specifies the position of the grouping (thousands) separator (,). There must be - * a 0 or 9 to the left and right of each grouping separator. - * '$': Specifies the location of the $ currency sign. This character may only be specified - * once. - * 'S' or 'MI': Specifies the position of a '-' or '+' sign (optional, only allowed once at - * the beginning or end of the format string). Note that 'S' prints '+' for positive values - * but 'MI' prints a space. - * 'PR': Only allowed at the end of the format string; specifies that the result string will be - * wrapped by angle brackets if the input value is negative. - * - * If `e` is a datetime, `format` shall be a valid datetime pattern, see - * Datetime Patterns. 
- * If `e` is a binary, it is converted to a string in one of the formats: - * 'base64': a base 64 string. - * 'hex': a string in the hexadecimal format. - * 'utf-8': the input binary is decoded to UTF-8 string. + * Convert `e` to a string based on the `format`. Throws an exception if the conversion fails. + * The format can consist of the following characters, case insensitive: '0' or '9': Specifies + * an expected digit between 0 and 9. A sequence of 0 or 9 in the format string matches a + * sequence of digits in the input value, generating a result string of the same length as the + * corresponding sequence in the format string. The result string is left-padded with zeros if + * the 0/9 sequence comprises more digits than the matching part of the decimal value, starts + * with 0, and is before the decimal point. Otherwise, it is padded with spaces. '.' or 'D': + * Specifies the position of the decimal point (optional, only allowed once). ',' or 'G': + * Specifies the position of the grouping (thousands) separator (,). There must be a 0 or 9 to + * the left and right of each grouping separator. '$': Specifies the location of the $ currency + * sign. This character may only be specified once. 'S' or 'MI': Specifies the position of a '-' + * or '+' sign (optional, only allowed once at the beginning or end of the format string). Note + * that 'S' prints '+' for positive values but 'MI' prints a space. 'PR': Only allowed at the + * end of the format string; specifies that the result string will be wrapped by angle brackets + * if the input value is negative. + * + * If `e` is a datetime, `format` shall be a valid datetime pattern, see Datetime + * Patterns. If `e` is a binary, it is converted to a string in one of the formats: + * 'base64': a base 64 string. 'hex': a string in the hexadecimal format. 'utf-8': the input + * binary is decoded to UTF-8 string. * * @group string_funcs * @since 3.5.0 @@ -4363,32 +4464,27 @@ object functions { // scalastyle:off line.size.limit /** - * Convert `e` to a string based on the `format`. - * Throws an exception if the conversion fails. The format can consist of the following - * characters, case insensitive: - * '0' or '9': Specifies an expected digit between 0 and 9. A sequence of 0 or 9 in the format - * string matches a sequence of digits in the input value, generating a result string of the - * same length as the corresponding sequence in the format string. The result string is - * left-padded with zeros if the 0/9 sequence comprises more digits than the matching part of - * the decimal value, starts with 0, and is before the decimal point. Otherwise, it is - * padded with spaces. - * '.' or 'D': Specifies the position of the decimal point (optional, only allowed once). - * ',' or 'G': Specifies the position of the grouping (thousands) separator (,). There must be - * a 0 or 9 to the left and right of each grouping separator. - * '$': Specifies the location of the $ currency sign. This character may only be specified - * once. - * 'S' or 'MI': Specifies the position of a '-' or '+' sign (optional, only allowed once at - * the beginning or end of the format string). Note that 'S' prints '+' for positive values - * but 'MI' prints a space. - * 'PR': Only allowed at the end of the format string; specifies that the result string will be - * wrapped by angle brackets if the input value is negative. - * - * If `e` is a datetime, `format` shall be a valid datetime pattern, see - * Datetime Patterns. 
- * If `e` is a binary, it is converted to a string in one of the formats: - * 'base64': a base 64 string. - * 'hex': a string in the hexadecimal format. - * 'utf-8': the input binary is decoded to UTF-8 string. + * Convert `e` to a string based on the `format`. Throws an exception if the conversion fails. + * The format can consist of the following characters, case insensitive: '0' or '9': Specifies + * an expected digit between 0 and 9. A sequence of 0 or 9 in the format string matches a + * sequence of digits in the input value, generating a result string of the same length as the + * corresponding sequence in the format string. The result string is left-padded with zeros if + * the 0/9 sequence comprises more digits than the matching part of the decimal value, starts + * with 0, and is before the decimal point. Otherwise, it is padded with spaces. '.' or 'D': + * Specifies the position of the decimal point (optional, only allowed once). ',' or 'G': + * Specifies the position of the grouping (thousands) separator (,). There must be a 0 or 9 to + * the left and right of each grouping separator. '$': Specifies the location of the $ currency + * sign. This character may only be specified once. 'S' or 'MI': Specifies the position of a '-' + * or '+' sign (optional, only allowed once at the beginning or end of the format string). Note + * that 'S' prints '+' for positive values but 'MI' prints a space. 'PR': Only allowed at the + * end of the format string; specifies that the result string will be wrapped by angle brackets + * if the input value is negative. + * + * If `e` is a datetime, `format` shall be a valid datetime pattern, see Datetime + * Patterns. If `e` is a binary, it is converted to a string in one of the formats: + * 'base64': a base 64 string. 'hex': a string in the hexadecimal format. 'utf-8': the input + * binary is decoded to UTF-8 string. * * @group string_funcs * @since 3.5.0 @@ -4397,24 +4493,21 @@ object functions { def to_varchar(e: Column, format: Column): Column = Column.fn("to_varchar", e, format) /** - * Convert string 'e' to a number based on the string format 'format'. - * Throws an exception if the conversion fails. The format can consist of the following - * characters, case insensitive: - * '0' or '9': Specifies an expected digit between 0 and 9. A sequence of 0 or 9 in the format - * string matches a sequence of digits in the input string. If the 0/9 sequence starts with - * 0 and is before the decimal point, it can only match a digit sequence of the same size. - * Otherwise, if the sequence starts with 9 or is after the decimal point, it can match a - * digit sequence that has the same or smaller size. - * '.' or 'D': Specifies the position of the decimal point (optional, only allowed once). - * ',' or 'G': Specifies the position of the grouping (thousands) separator (,). There must be - * a 0 or 9 to the left and right of each grouping separator. 'expr' must match the - * grouping separator relevant for the size of the number. - * '$': Specifies the location of the $ currency sign. This character may only be specified - * once. - * 'S' or 'MI': Specifies the position of a '-' or '+' sign (optional, only allowed once at - * the beginning or end of the format string). Note that 'S' allows '-' but 'MI' does not. - * 'PR': Only allowed at the end of the format string; specifies that 'expr' indicates a - * negative number with wrapping angled brackets. + * Convert string 'e' to a number based on the string format 'format'. 
Throws an exception if + * the conversion fails. The format can consist of the following characters, case insensitive: + * '0' or '9': Specifies an expected digit between 0 and 9. A sequence of 0 or 9 in the format + * string matches a sequence of digits in the input string. If the 0/9 sequence starts with 0 + * and is before the decimal point, it can only match a digit sequence of the same size. + * Otherwise, if the sequence starts with 9 or is after the decimal point, it can match a digit + * sequence that has the same or smaller size. '.' or 'D': Specifies the position of the decimal + * point (optional, only allowed once). ',' or 'G': Specifies the position of the grouping + * (thousands) separator (,). There must be a 0 or 9 to the left and right of each grouping + * separator. 'expr' must match the grouping separator relevant for the size of the number. '$': + * Specifies the location of the $ currency sign. This character may only be specified once. 'S' + * or 'MI': Specifies the position of a '-' or '+' sign (optional, only allowed once at the + * beginning or end of the format string). Note that 'S' allows '-' but 'MI' does not. 'PR': + * Only allowed at the end of the format string; specifies that 'expr' indicates a negative + * number with wrapping angled brackets. * * @group string_funcs * @since 3.5.0 @@ -4452,11 +4545,10 @@ object functions { def replace(src: Column, search: Column): Column = Column.fn("replace", src, search) /** - * Splits `str` by delimiter and return requested part of the split (1-based). - * If any input is null, returns null. if `partNum` is out of range of split parts, - * returns empty string. If `partNum` is 0, throws an error. If `partNum` is negative, - * the parts are counted backward from the end of the string. - * If the `delimiter` is an empty string, the `str` is not split. + * Splits `str` by delimiter and return requested part of the split (1-based). If any input is + * null, returns null. if `partNum` is out of range of split parts, returns empty string. If + * `partNum` is 0, throws an error. If `partNum` is negative, the parts are counted backward + * from the end of the string. If the `delimiter` is an empty string, the `str` is not split. * * @group string_funcs * @since 3.5.0 @@ -4465,8 +4557,8 @@ object functions { Column.fn("split_part", str, delimiter, partNum) /** - * Returns the substring of `str` that starts at `pos` and is of length `len`, - * or the slice of byte array that starts at `pos` and is of length `len`. + * Returns the substring of `str` that starts at `pos` and is of length `len`, or the slice of + * byte array that starts at `pos` and is of length `len`. * * @group string_funcs * @since 3.5.0 @@ -4475,8 +4567,8 @@ object functions { Column.fn("substr", str, pos, len) /** - * Returns the substring of `str` that starts at `pos`, - * or the slice of byte array that starts at `pos`. + * Returns the substring of `str` that starts at `pos`, or the slice of byte array that starts + * at `pos`. * * @group string_funcs * @since 3.5.0 @@ -4511,8 +4603,8 @@ object functions { Column.fn("printf", (format +: arguments): _*) /** - * Decodes a `str` in 'application/x-www-form-urlencoded' format - * using a specific encoding scheme. + * Decodes a `str` in 'application/x-www-form-urlencoded' format using a specific encoding + * scheme. 
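To make the parsing direction concrete, here is a small sketch of `to_number`, `split_part` and `substr` as documented above (illustrative only; reuses the assumed `spark` session):

{{{
spark.range(1).select(
  to_number(lit("$78.12"), lit("$99.99")).as("parsed"),     // parses the formatted string back into a decimal
  split_part(lit("a.b.c"), lit("."), lit(2)).as("second"),  // 1-based -> "b"
  split_part(lit("a.b.c"), lit("."), lit(-1)).as("last"),   // negative counts from the end -> "c"
  substr(lit("Spark SQL"), lit(7), lit(3)).as("sub")        // -> "SQL"
).show(false)
}}}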
* * @group url_funcs * @since 3.5.0 @@ -4520,8 +4612,8 @@ object functions { def url_decode(str: Column): Column = Column.fn("url_decode", str) /** - * This is a special version of `url_decode` that performs the same operation, but returns - * a NULL value instead of raising an error if the decoding cannot be performed. + * This is a special version of `url_decode` that performs the same operation, but returns a + * NULL value instead of raising an error if the decoding cannot be performed. * * @group url_funcs * @since 4.0.0 @@ -4529,8 +4621,8 @@ object functions { def try_url_decode(str: Column): Column = Column.fn("try_url_decode", str) /** - * Translates a string into 'application/x-www-form-urlencoded' format - * using a specific encoding scheme. + * Translates a string into 'application/x-www-form-urlencoded' format using a specific encoding + * scheme. * * @group url_funcs * @since 3.5.0 @@ -4538,8 +4630,8 @@ object functions { def url_encode(str: Column): Column = Column.fn("url_encode", str) /** - * Returns the position of the first occurrence of `substr` in `str` after position `start`. - * The given `start` and return value are 1-based. + * Returns the position of the first occurrence of `substr` in `str` after position `start`. The + * given `start` and return value are 1-based. * * @group string_funcs * @since 3.5.0 @@ -4548,8 +4640,8 @@ object functions { Column.fn("position", substr, str, start) /** - * Returns the position of the first occurrence of `substr` in `str` after position `1`. - * The return value are 1-based. + * Returns the position of the first occurrence of `substr` in `str` after position `1`. The + * return value are 1-based. * * @group string_funcs * @since 3.5.0 @@ -4558,9 +4650,9 @@ object functions { Column.fn("position", substr, str) /** - * Returns a boolean. The value is True if str ends with suffix. - * Returns NULL if either input expression is NULL. Otherwise, returns False. - * Both str or suffix must be of STRING or BINARY type. + * Returns a boolean. The value is True if str ends with suffix. Returns NULL if either input + * expression is NULL. Otherwise, returns False. Both str or suffix must be of STRING or BINARY + * type. * * @group string_funcs * @since 3.5.0 @@ -4569,9 +4661,9 @@ object functions { Column.fn("endswith", str, suffix) /** - * Returns a boolean. The value is True if str starts with prefix. - * Returns NULL if either input expression is NULL. Otherwise, returns False. - * Both str or prefix must be of STRING or BINARY type. + * Returns a boolean. The value is True if str starts with prefix. Returns NULL if either input + * expression is NULL. Otherwise, returns False. Both str or prefix must be of STRING or BINARY + * type. * * @group string_funcs * @since 3.5.0 @@ -4580,8 +4672,8 @@ object functions { Column.fn("startswith", str, prefix) /** - * Returns the ASCII character having the binary equivalent to `n`. - * If n is larger than 256 the result is equivalent to char(n % 256) + * Returns the ASCII character having the binary equivalent to `n`. If n is larger than 256 the + * result is equivalent to char(n % 256) * * @group string_funcs * @since 3.5.0 @@ -4633,9 +4725,8 @@ object functions { def try_to_number(e: Column, format: Column): Column = Column.fn("try_to_number", e, format) /** - * Returns the character length of string data or number of bytes of binary data. - * The length of string data includes the trailing spaces. - * The length of binary data includes binary zeros. 
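A brief sketch of the URL helpers and `position` described above (reuses the assumed `spark` session; the commented results are not part of this diff):

{{{
spark.range(1).select(
  url_encode(lit("a b&c")).as("encoded"),                // form-encodes spaces and reserved characters
  url_decode(lit("a%20b")).as("decoded"),
  try_url_decode(lit("%")).as("bad_input"),              // malformed escape -> null instead of an error
  position(lit("ar"), lit("bar/bar"), lit(3)).as("pos")  // 1-based; the search starts at position 3
).show(false)
}}}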
+ * Returns the character length of string data or number of bytes of binary data. The length of + * string data includes the trailing spaces. The length of binary data includes binary zeros. * * @group string_funcs * @since 3.5.0 @@ -4643,9 +4734,8 @@ object functions { def char_length(str: Column): Column = Column.fn("char_length", str) /** - * Returns the character length of string data or number of bytes of binary data. - * The length of string data includes the trailing spaces. - * The length of binary data includes binary zeros. + * Returns the character length of string data or number of bytes of binary data. The length of + * string data includes the trailing spaces. The length of binary data includes binary zeros. * * @group string_funcs * @since 3.5.0 @@ -4653,8 +4743,8 @@ object functions { def character_length(str: Column): Column = Column.fn("character_length", str) /** - * Returns the ASCII character having the binary equivalent to `n`. - * If n is larger than 256 the result is equivalent to chr(n % 256) + * Returns the ASCII character having the binary equivalent to `n`. If n is larger than 256 the + * result is equivalent to chr(n % 256) * * @group string_funcs * @since 3.5.0 @@ -4662,9 +4752,9 @@ object functions { def chr(n: Column): Column = Column.fn("chr", n) /** - * Returns a boolean. The value is True if right is found inside left. - * Returns NULL if either input expression is NULL. Otherwise, returns False. - * Both left or right must be of STRING or BINARY type. + * Returns a boolean. The value is True if right is found inside left. Returns NULL if either + * input expression is NULL. Otherwise, returns False. Both left or right must be of STRING or + * BINARY type. * * @group string_funcs * @since 3.5.0 @@ -4672,10 +4762,10 @@ object functions { def contains(left: Column, right: Column): Column = Column.fn("contains", left, right) /** - * Returns the `n`-th input, e.g., returns `input2` when `n` is 2. - * The function returns NULL if the index exceeds the length of the array - * and `spark.sql.ansi.enabled` is set to false. If `spark.sql.ansi.enabled` is set to true, - * it throws ArrayIndexOutOfBoundsException for invalid indices. + * Returns the `n`-th input, e.g., returns `input2` when `n` is 2. The function returns NULL if + * the index exceeds the length of the array and `spark.sql.ansi.enabled` is set to false. If + * `spark.sql.ansi.enabled` is set to true, it throws ArrayIndexOutOfBoundsException for invalid + * indices. * * @group string_funcs * @since 3.5.0 @@ -4684,9 +4774,9 @@ object functions { def elt(inputs: Column*): Column = Column.fn("elt", inputs: _*) /** - * Returns the index (1-based) of the given string (`str`) in the comma-delimited - * list (`strArray`). Returns 0, if the string was not found or if the given string (`str`) - * contains a comma. + * Returns the index (1-based) of the given string (`str`) in the comma-delimited list + * (`strArray`). Returns 0, if the string was not found or if the given string (`str`) contains + * a comma. * * @group string_funcs * @since 3.5.0 @@ -4748,8 +4838,8 @@ object functions { def ucase(str: Column): Column = Column.fn("ucase", str) /** - * Returns the leftmost `len`(`len` can be string type) characters from the string `str`, - * if `len` is less or equal than 0 the result is an empty string. + * Returns the leftmost `len`(`len` can be string type) characters from the string `str`, if + * `len` is less or equal than 0 the result is an empty string. 
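The `contains`, `elt` and `chr` behavior above can be seen with a tiny sketch (illustrative, reusing the assumed `spark` session):

{{{
spark.range(1).select(
  contains(lit("Spark SQL"), lit("SQL")).as("has_sql"),                // true
  elt(lit(2), lit("scala"), lit("java"), lit("python")).as("second"),  // returns the 2nd input -> "java"
  chr(lit(65)).as("letter")                                            // ASCII 65 -> "A"
).show(false)
}}}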
* * @group string_funcs * @since 3.5.0 @@ -4757,8 +4847,8 @@ object functions { def left(str: Column, len: Column): Column = Column.fn("left", str, len) /** - * Returns the rightmost `len`(`len` can be string type) characters from the string `str`, - * if `len` is less or equal than 0 the result is an empty string. + * Returns the rightmost `len`(`len` can be string type) characters from the string `str`, if + * `len` is less or equal than 0 the result is an empty string. * * @group string_funcs * @since 3.5.0 @@ -4772,23 +4862,29 @@ object functions { /** * Returns the date that is `numMonths` after `startDate`. * - * @param startDate A date, timestamp or string. If a string, the data must be in a format that - * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param numMonths The number of months to add to `startDate`, can be negative to subtract months - * @return A date, or null if `startDate` was a string that could not be cast to a date + * @param startDate + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param numMonths + * The number of months to add to `startDate`, can be negative to subtract months + * @return + * A date, or null if `startDate` was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ - def add_months(startDate: Column, numMonths: Int): Column = add_months(startDate, lit(numMonths)) + def add_months(startDate: Column, numMonths: Int): Column = + add_months(startDate, lit(numMonths)) /** * Returns the date that is `numMonths` after `startDate`. * - * @param startDate A date, timestamp or string. If a string, the data must be in a format that - * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param numMonths A column of the number of months to add to `startDate`, can be negative to - * subtract months - * @return A date, or null if `startDate` was a string that could not be cast to a date + * @param startDate + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param numMonths + * A column of the number of months to add to `startDate`, can be negative to subtract months + * @return + * A date, or null if `startDate` was a string that could not be cast to a date * @group datetime_funcs * @since 3.0.0 */ @@ -4796,8 +4892,8 @@ object functions { Column.fn("add_months", startDate, numMonths) /** - * Returns the current date at the start of query evaluation as a date column. - * All calls of current_date within the same query return the same value. + * Returns the current date at the start of query evaluation as a date column. All calls of + * current_date within the same query return the same value. * * @group datetime_funcs * @since 3.5.0 @@ -4805,8 +4901,8 @@ object functions { def curdate(): Column = Column.fn("curdate") /** - * Returns the current date at the start of query evaluation as a date column. - * All calls of current_date within the same query return the same value. + * Returns the current date at the start of query evaluation as a date column. All calls of + * current_date within the same query return the same value. 
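A short sketch of the two `add_months` overloads and `current_date` documented above (reuses the assumed `spark` session and `spark.implicits._`):

{{{
Seq("2015-07-27").toDF("d").select(
  add_months(col("d"), 2).as("plus_two"),         // Int overload
  add_months(col("d"), lit(-1)).as("minus_one"),  // Column overload; a negative value subtracts months
  current_date().as("today")                      // same value for every row of the query
).show(false)
}}}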
* * @group datetime_funcs * @since 1.5.0 @@ -4822,8 +4918,8 @@ object functions { def current_timezone(): Column = Column.fn("current_timezone") /** - * Returns the current timestamp at the start of query evaluation as a timestamp column. - * All calls of current_timestamp within the same query return the same value. + * Returns the current timestamp at the start of query evaluation as a timestamp column. All + * calls of current_timestamp within the same query return the same value. * * @group datetime_funcs * @since 1.5.0 @@ -4839,9 +4935,9 @@ object functions { def now(): Column = Column.fn("now") /** - * Returns the current timestamp without time zone at the start of query evaluation - * as a timestamp without time zone column. - * All calls of localtimestamp within the same query return the same value. + * Returns the current timestamp without time zone at the start of query evaluation as a + * timestamp without time zone column. All calls of localtimestamp within the same query return + * the same value. * * @group datetime_funcs * @since 3.3.0 @@ -4852,17 +4948,21 @@ object functions { * Converts a date/timestamp/string to a value of string in the format specified by the date * format given by the second argument. * - * See - * Datetime Patterns - * for valid date and time format patterns - * - * @param dateExpr A date, timestamp or string. If a string, the data must be in a format that - * can be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param format A pattern `dd.MM.yyyy` would return a string like `18.03.1993` - * @return A string, or null if `dateExpr` was a string that could not be cast to a timestamp - * @note Use specialized functions like [[year]] whenever possible as they benefit from a - * specialized implementation. - * @throws IllegalArgumentException if the `format` pattern is invalid + * See Datetime + * Patterns for valid date and time format patterns + * + * @param dateExpr + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param format + * A pattern `dd.MM.yyyy` would return a string like `18.03.1993` + * @return + * A string, or null if `dateExpr` was a string that could not be cast to a timestamp + * @note + * Use specialized functions like [[year]] whenever possible as they benefit from a + * specialized implementation. + * @throws IllegalArgumentException + * if the `format` pattern is invalid * @group datetime_funcs * @since 1.5.0 */ @@ -4872,10 +4972,13 @@ object functions { /** * Returns the date that is `days` days after `start` * - * @param start A date, timestamp or string. If a string, the data must be in a format that - * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param days The number of days to add to `start`, can be negative to subtract days - * @return A date, or null if `start` was a string that could not be cast to a date + * @param start + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param days + * The number of days to add to `start`, can be negative to subtract days + * @return + * A date, or null if `start` was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -4884,10 +4987,13 @@ object functions { /** * Returns the date that is `days` days after `start` * - * @param start A date, timestamp or string. 
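As a one-line illustration of the `date_format` contract above (sketch only, reusing the assumed `spark` session; the output claim matches the pattern example in the Scaladoc):

{{{
Seq("1993-03-18 12:01:19").toDF("ts").select(
  date_format(col("ts"), "dd.MM.yyyy").as("formatted"),  // -> "18.03.1993"
  date_add(col("ts"), 7).as("week_later")                // Int overload of date_add
).show(false)
}}}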
If a string, the data must be in a format that - * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param days A column of the number of days to add to `start`, can be negative to subtract days - * @return A date, or null if `start` was a string that could not be cast to a date + * @param start + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param days + * A column of the number of days to add to `start`, can be negative to subtract days + * @return + * A date, or null if `start` was a string that could not be cast to a date * @group datetime_funcs * @since 3.0.0 */ @@ -4896,10 +5002,13 @@ object functions { /** * Returns the date that is `days` days after `start` * - * @param start A date, timestamp or string. If a string, the data must be in a format that - * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param days A column of the number of days to add to `start`, can be negative to subtract days - * @return A date, or null if `start` was a string that could not be cast to a date + * @param start + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param days + * A column of the number of days to add to `start`, can be negative to subtract days + * @return + * A date, or null if `start` was a string that could not be cast to a date * @group datetime_funcs * @since 3.5.0 */ @@ -4908,10 +5017,13 @@ object functions { /** * Returns the date that is `days` days before `start` * - * @param start A date, timestamp or string. If a string, the data must be in a format that - * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param days The number of days to subtract from `start`, can be negative to add days - * @return A date, or null if `start` was a string that could not be cast to a date + * @param start + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param days + * The number of days to subtract from `start`, can be negative to add days + * @return + * A date, or null if `start` was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -4920,11 +5032,13 @@ object functions { /** * Returns the date that is `days` days before `start` * - * @param start A date, timestamp or string. If a string, the data must be in a format that - * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param days A column of the number of days to subtract from `start`, can be negative to add - * days - * @return A date, or null if `start` was a string that could not be cast to a date + * @param start + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param days + * A column of the number of days to subtract from `start`, can be negative to add days + * @return + * A date, or null if `start` was a string that could not be cast to a date * @group datetime_funcs * @since 3.0.0 */ @@ -4940,12 +5054,15 @@ object functions { * // returns 1 * }}} * - * @param end A date, timestamp or string. 
If a string, the data must be in a format that - * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param start A date, timestamp or string. If a string, the data must be in a format that - * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @return An integer, or null if either `end` or `start` were strings that could not be cast to - * a date. Negative if `end` is before `start` + * @param end + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param start + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return + * An integer, or null if either `end` or `start` were strings that could not be cast to a + * date. Negative if `end` is before `start` * @group datetime_funcs * @since 1.5.0 */ @@ -4960,12 +5077,15 @@ object functions { * // returns 1 * }}} * - * @param end A date, timestamp or string. If a string, the data must be in a format that - * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param start A date, timestamp or string. If a string, the data must be in a format that - * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @return An integer, or null if either `end` or `start` were strings that could not be cast to - * a date. Negative if `end` is before `start` + * @param end + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param start + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return + * An integer, or null if either `end` or `start` were strings that could not be cast to a + * date. Negative if `end` is before `start` * @group datetime_funcs * @since 3.5.0 */ @@ -4981,7 +5101,8 @@ object functions { /** * Extracts the year as an integer from a given date/timestamp/string. - * @return An integer, or null if the input was a string that could not be cast to a date + * @return + * An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -4989,7 +5110,8 @@ object functions { /** * Extracts the quarter as an integer from a given date/timestamp/string. - * @return An integer, or null if the input was a string that could not be cast to a date + * @return + * An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -4997,16 +5119,18 @@ object functions { /** * Extracts the month as an integer from a given date/timestamp/string. - * @return An integer, or null if the input was a string that could not be cast to a date + * @return + * An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ def month(e: Column): Column = Column.fn("month", e) /** - * Extracts the day of the week as an integer from a given date/timestamp/string. - * Ranges from 1 for a Sunday through to 7 for a Saturday - * @return An integer, or null if the input was a string that could not be cast to a date + * Extracts the day of the week as an integer from a given date/timestamp/string. 
Ranges from 1 + * for a Sunday through to 7 for a Saturday + * @return + * An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 2.3.0 */ @@ -5014,7 +5138,8 @@ object functions { /** * Extracts the day of the month as an integer from a given date/timestamp/string. - * @return An integer, or null if the input was a string that could not be cast to a date + * @return + * An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -5022,7 +5147,8 @@ object functions { /** * Extracts the day of the month as an integer from a given date/timestamp/string. - * @return An integer, or null if the input was a string that could not be cast to a date + * @return + * An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 3.5.0 */ @@ -5030,7 +5156,8 @@ object functions { /** * Extracts the day of the year as an integer from a given date/timestamp/string. - * @return An integer, or null if the input was a string that could not be cast to a date + * @return + * An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -5038,7 +5165,8 @@ object functions { /** * Extracts the hours as an integer from a given date/timestamp/string. - * @return An integer, or null if the input was a string that could not be cast to a date + * @return + * An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -5047,9 +5175,12 @@ object functions { /** * Extracts a part of the date/timestamp or interval source. * - * @param field selects which part of the source should be extracted. - * @param source a date/timestamp or interval column from where `field` should be extracted. - * @return a part of the date/timestamp or interval source + * @param field + * selects which part of the source should be extracted. + * @param source + * a date/timestamp or interval column from where `field` should be extracted. + * @return + * a part of the date/timestamp or interval source * @group datetime_funcs * @since 3.5.0 */ @@ -5060,10 +5191,13 @@ object functions { /** * Extracts a part of the date/timestamp or interval source. * - * @param field selects which part of the source should be extracted, and supported string values - * are as same as the fields of the equivalent function `extract`. - * @param source a date/timestamp or interval column from where `field` should be extracted. - * @return a part of the date/timestamp or interval source + * @param field + * selects which part of the source should be extracted, and supported string values are as + * same as the fields of the equivalent function `extract`. + * @param source + * a date/timestamp or interval column from where `field` should be extracted. + * @return + * a part of the date/timestamp or interval source * @group datetime_funcs * @since 3.5.0 */ @@ -5074,10 +5208,13 @@ object functions { /** * Extracts a part of the date/timestamp or interval source. * - * @param field selects which part of the source should be extracted, and supported string values - * are as same as the fields of the equivalent function `EXTRACT`. - * @param source a date/timestamp or interval column from where `field` should be extracted. 
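A compact sketch of the field-extraction helpers documented above (`extract`, `date_part`, `dayofweek`); the field names follow the SQL `EXTRACT` fields, the explicit timestamp cast is only there to keep the input type unambiguous, and the snippet reuses the assumed `spark` session:

{{{
Seq("2019-08-12 01:02:03").toDF("ts").select(
  extract(lit("YEAR"), col("ts").cast("timestamp")).as("year_part"),
  date_part(lit("MINUTE"), col("ts").cast("timestamp")).as("minute_part"),
  dayofweek(col("ts").cast("timestamp")).as("dow")  // 1 = Sunday ... 7 = Saturday
).show(false)
}}}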
- * @return a part of the date/timestamp or interval source + * @param field + * selects which part of the source should be extracted, and supported string values are as + * same as the fields of the equivalent function `EXTRACT`. + * @param source + * a date/timestamp or interval column from where `field` should be extracted. + * @return + * a part of the date/timestamp or interval source * @group datetime_funcs * @since 3.5.0 */ @@ -5086,13 +5223,14 @@ object functions { } /** - * Returns the last day of the month which the given date belongs to. - * For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the - * month in July 2015. + * Returns the last day of the month which the given date belongs to. For example, input + * "2015-07-27" returns "2015-07-31" since July 31 is the last day of the month in July 2015. * - * @param e A date, timestamp or string. If a string, the data must be in a format that can be - * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @return A date, or null if the input was a string that could not be cast to a date + * @param e + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return + * A date, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -5100,7 +5238,8 @@ object functions { /** * Extracts the minutes as an integer from a given date/timestamp/string. - * @return An integer, or null if the input was a string that could not be cast to a date + * @return + * An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -5115,7 +5254,8 @@ object functions { def weekday(e: Column): Column = Column.fn("weekday", e) /** - * @return A date created from year, month and day fields. + * @return + * A date created from year, month and day fields. * @group datetime_funcs * @since 3.3.0 */ @@ -5126,7 +5266,8 @@ object functions { * Returns number of months between dates `start` and `end`. * * A whole number is returned if both inputs have the same day of month or both are the last day - * of their respective months. Otherwise, the difference is calculated assuming 31 days per month. + * of their respective months. Otherwise, the difference is calculated assuming 31 days per + * month. * * For example: * {{{ @@ -5135,12 +5276,15 @@ object functions { * months_between("2017-06-01", "2017-06-16 12:00:00") // returns -0.5 * }}} * - * @param end A date, timestamp or string. If a string, the data must be in a format that can - * be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param start A date, timestamp or string. If a string, the data must be in a format that can - * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @return A double, or null if either `end` or `start` were strings that could not be cast to a - * timestamp. Negative if `end` is before `start` + * @param end + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param start + * A date, timestamp or string. If a string, the data must be in a format that can cast to a + * timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return + * A double, or null if either `end` or `start` were strings that could not be cast to a + * timestamp. 
Negative if `end` is before `start` * @group datetime_funcs * @since 1.5.0 */ @@ -5163,11 +5307,14 @@ object functions { * For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first * Sunday after 2015-07-27. * - * @param date A date, timestamp or string. If a string, the data must be in a format that - * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param dayOfWeek Case insensitive, and accepts: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun" - * @return A date, or null if `date` was a string that could not be cast to a date or if - * `dayOfWeek` was an invalid value + * @param date + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param dayOfWeek + * Case insensitive, and accepts: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun" + * @return + * A date, or null if `date` was a string that could not be cast to a date or if `dayOfWeek` + * was an invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -5180,12 +5327,15 @@ object functions { * For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first * Sunday after 2015-07-27. * - * @param date A date, timestamp or string. If a string, the data must be in a format that - * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param dayOfWeek A column of the day of week. Case insensitive, and accepts: "Mon", "Tue", - * "Wed", "Thu", "Fri", "Sat", "Sun" - * @return A date, or null if `date` was a string that could not be cast to a date or if - * `dayOfWeek` was an invalid value + * @param date + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param dayOfWeek + * A column of the day of week. Case insensitive, and accepts: "Mon", "Tue", "Wed", "Thu", + * "Fri", "Sat", "Sun" + * @return + * A date, or null if `date` was a string that could not be cast to a date or if `dayOfWeek` + * was an invalid value * @group datetime_funcs * @since 3.2.0 */ @@ -5194,7 +5344,8 @@ object functions { /** * Extracts the seconds as an integer from a given date/timestamp/string. - * @return An integer, or null if the input was a string that could not be cast to a timestamp + * @return + * An integer, or null if the input was a string that could not be cast to a timestamp * @group datetime_funcs * @since 1.5.0 */ @@ -5206,7 +5357,8 @@ object functions { * A week is considered to start on a Monday and week 1 is the first week with more than 3 days, * as defined by ISO 8601 * - * @return An integer, or null if the input was a string that could not be cast to a date + * @return + * An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -5214,12 +5366,14 @@ object functions { /** * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string - * representing the timestamp of that moment in the current system time zone in the - * yyyy-MM-dd HH:mm:ss format. - * - * @param ut A number of a type that is castable to a long, such as string or integer. Can be - * negative for timestamps before the unix epoch - * @return A string, or null if the input was a string that could not be cast to a long + * representing the timestamp of that moment in the current system time zone in the yyyy-MM-dd + * HH:mm:ss format. 
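The `next_day` behavior called out above ('2015-07-27' followed by "Sunday" gives 2015-08-02) can be reproduced with a small sketch (reuses the assumed `spark` session):

{{{
Seq("2015-07-27").toDF("d").select(
  next_day(col("d"), "Sunday").as("next_sunday"),  // -> 2015-08-02, per the example above
  weekofyear(col("d")).as("iso_week"),
  second(lit("2015-07-27 10:30:45").cast("timestamp")).as("sec")
).show(false)
}}}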
+ * + * @param ut + * A number of a type that is castable to a long, such as string or integer. Can be negative + * for timestamps before the unix epoch + * @return + * A string, or null if the input was a string that could not be cast to a long * @group datetime_funcs * @since 1.5.0 */ @@ -5230,15 +5384,17 @@ object functions { * representing the timestamp of that moment in the current system time zone in the given * format. * - * See - * Datetime Patterns - * for valid date and time format patterns - * - * @param ut A number of a type that is castable to a long, such as string or integer. Can be - * negative for timestamps before the unix epoch - * @param f A date time pattern that the input will be formatted to - * @return A string, or null if `ut` was a string that could not be cast to a long or `f` was - * an invalid date time pattern + * See Datetime + * Patterns for valid date and time format patterns + * + * @param ut + * A number of a type that is castable to a long, such as string or integer. Can be negative + * for timestamps before the unix epoch + * @param f + * A date time pattern that the input will be formatted to + * @return + * A string, or null if `ut` was a string that could not be cast to a long or `f` was an + * invalid date time pattern * @group datetime_funcs * @since 1.5.0 */ @@ -5248,8 +5404,9 @@ object functions { /** * Returns the current Unix timestamp (in seconds) as a long. * - * @note All calls of `unix_timestamp` within the same query return the same value - * (i.e. the current timestamp is calculated at the start of query evaluation). + * @note + * All calls of `unix_timestamp` within the same query return the same value (i.e. the current + * timestamp is calculated at the start of query evaluation). * * @group datetime_funcs * @since 1.5.0 @@ -5257,12 +5414,14 @@ object functions { def unix_timestamp(): Column = unix_timestamp(current_timestamp()) /** - * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), - * using the default timezone and the default locale. + * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), using the + * default timezone and the default locale. * - * @param s A date, timestamp or string. If a string, the data must be in the - * `yyyy-MM-dd HH:mm:ss` format - * @return A long, or null if the input was a string not of the correct format + * @param s + * A date, timestamp or string. If a string, the data must be in the `yyyy-MM-dd HH:mm:ss` + * format + * @return + * A long, or null if the input was a string not of the correct format * @group datetime_funcs * @since 1.5.0 */ @@ -5271,15 +5430,17 @@ object functions { /** * Converts time string with given pattern to Unix timestamp (in seconds). * - * See - * Datetime Patterns - * for valid date and time format patterns - * - * @param s A date, timestamp or string. If a string, the data must be in a format that can be - * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param p A date time pattern detailing the format of `s` when `s` is a string - * @return A long, or null if `s` was a string that could not be cast to a date or `p` was - * an invalid format + * See Datetime + * Patterns for valid date and time format patterns + * + * @param s + * A date, timestamp or string. 
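A minimal round-trip sketch for `unix_timestamp` and `from_unixtime` as documented above (reuses the assumed `spark` session; the exact epoch value depends on the session time zone):

{{{
Seq("2015-07-27 10:00:00").toDF("t").select(
  unix_timestamp(col("t")).as("epoch_seconds"),    // default yyyy-MM-dd HH:mm:ss pattern
  from_unixtime(unix_timestamp(col("t")), "yyyy/MM/dd HH:mm").as("reformatted")
).show(false)
}}}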
If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param p + * A date time pattern detailing the format of `s` when `s` is a string + * @return + * A long, or null if `s` was a string that could not be cast to a date or `p` was an invalid + * format * @group datetime_funcs * @since 1.5.0 */ @@ -5289,9 +5450,11 @@ object functions { /** * Converts to a timestamp by casting rules to `TimestampType`. * - * @param s A date, timestamp or string. If a string, the data must be in a format that can be - * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @return A timestamp, or null if the input was a string that could not be cast to a timestamp + * @param s + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return + * A timestamp, or null if the input was a string that could not be cast to a timestamp * @group datetime_funcs * @since 2.2.0 */ @@ -5300,15 +5463,17 @@ object functions { /** * Converts time string with the given pattern to timestamp. * - * See - * Datetime Patterns - * for valid date and time format patterns - * - * @param s A date, timestamp or string. If a string, the data must be in a format that can be - * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param fmt A date time pattern detailing the format of `s` when `s` is a string - * @return A timestamp, or null if `s` was a string that could not be cast to a timestamp or - * `fmt` was an invalid format + * See Datetime + * Patterns for valid date and time format patterns + * + * @param s + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param fmt + * A date time pattern detailing the format of `s` when `s` is a string + * @return + * A timestamp, or null if `s` was a string that could not be cast to a timestamp or `fmt` was + * an invalid format * @group datetime_funcs * @since 2.2.0 */ @@ -5326,9 +5491,9 @@ object functions { Column.fn("try_to_timestamp", s, format) /** - * Parses the `s` to a timestamp. The function always returns null on an invalid - * input with`/`without ANSI SQL mode enabled. It follows casting rules to a timestamp. The - * result data type is consistent with the value of configuration `spark.sql.timestampType`. + * Parses the `s` to a timestamp. The function always returns null on an invalid input + * with`/`without ANSI SQL mode enabled. It follows casting rules to a timestamp. The result + * data type is consistent with the value of configuration `spark.sql.timestampType`. * * @group datetime_funcs * @since 3.5.0 @@ -5346,15 +5511,17 @@ object functions { /** * Converts the column into a `DateType` with a specified format * - * See - * Datetime Patterns - * for valid date and time format patterns - * - * @param e A date, timestamp or string. If a string, the data must be in a format that can be - * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param fmt A date time pattern detailing the format of `e` when `e`is a string - * @return A date, or null if `e` was a string that could not be cast to a date or `fmt` was an - * invalid format + * See Datetime + * Patterns for valid date and time format patterns + * + * @param e + * A date, timestamp or string. 
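Likewise, a sketch of `to_timestamp` and its error-tolerant `try_to_timestamp` counterpart described above (reuses the assumed `spark` session):

{{{
Seq("2016-12-31 00:12:00").toDF("t").select(
  to_timestamp(col("t")).as("ts"),                                  // casting rules to TimestampType
  to_timestamp(col("t"), "yyyy-MM-dd HH:mm:ss").as("ts_with_fmt"),
  try_to_timestamp(lit("not a timestamp")).as("bad_input")          // null instead of raising an error
).show(false)
}}}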
If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param fmt + * A date time pattern detailing the format of `e` when `e`is a string + * @return + * A date, or null if `e` was a string that could not be cast to a date or `fmt` was an + * invalid format * @group datetime_funcs * @since 2.2.0 */ @@ -5377,8 +5544,8 @@ object functions { def unix_micros(e: Column): Column = Column.fn("unix_micros", e) /** - * Returns the number of milliseconds since 1970-01-01 00:00:00 UTC. - * Truncates higher levels of precision. + * Returns the number of milliseconds since 1970-01-01 00:00:00 UTC. Truncates higher levels of + * precision. * * @group datetime_funcs * @since 3.5.0 @@ -5386,8 +5553,8 @@ object functions { def unix_millis(e: Column): Column = Column.fn("unix_millis", e) /** - * Returns the number of seconds since 1970-01-01 00:00:00 UTC. - * Truncates higher levels of precision. + * Returns the number of seconds since 1970-01-01 00:00:00 UTC. Truncates higher levels of + * precision. * * @group datetime_funcs * @since 3.5.0 @@ -5399,14 +5566,16 @@ object functions { * * For example, `trunc("2018-11-19 12:01:19", "year")` returns 2018-01-01 * - * @param date A date, timestamp or string. If a string, the data must be in a format that can be - * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param format: 'year', 'yyyy', 'yy' to truncate by year, - * or 'month', 'mon', 'mm' to truncate by month - * Other options are: 'week', 'quarter' + * @param date + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param format: + * 'year', 'yyyy', 'yy' to truncate by year, or 'month', 'mon', 'mm' to truncate by month + * Other options are: 'week', 'quarter' * - * @return A date, or null if `date` was a string that could not be cast to a date or `format` - * was an invalid value + * @return + * A date, or null if `date` was a string that could not be cast to a date or `format` was an + * invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -5417,15 +5586,16 @@ object functions { * * For example, `date_trunc("year", "2018-11-19 12:01:19")` returns 2018-01-01 00:00:00 * - * @param format: 'year', 'yyyy', 'yy' to truncate by year, - * 'month', 'mon', 'mm' to truncate by month, - * 'day', 'dd' to truncate by day, - * Other options are: - * 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'week', 'quarter' - * @param timestamp A date, timestamp or string. If a string, the data must be in a format that - * can be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @return A timestamp, or null if `timestamp` was a string that could not be cast to a timestamp - * or `format` was an invalid value + * @param format: + * 'year', 'yyyy', 'yy' to truncate by year, 'month', 'mon', 'mm' to truncate by month, 'day', + * 'dd' to truncate by day, Other options are: 'microsecond', 'millisecond', 'second', + * 'minute', 'hour', 'week', 'quarter' + * @param timestamp + * A date, timestamp or string. 
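The two truncation functions above differ in argument order and return type; a small sketch (reuses the assumed `spark` session; the commented results follow the Scaladoc examples):

{{{
Seq("2018-11-19 12:01:19").toDF("t").select(
  trunc(col("t"), "year").as("year_start"),      // -> 2018-01-01, returned as a date
  date_trunc("hour", col("t")).as("hour_start")  // -> 2018-11-19 12:00:00, returned as a timestamp; note the flipped argument order
).show(false)
}}}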
If a string, the data must be in a format that can be cast to + * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return + * A timestamp, or null if `timestamp` was a string that could not be cast to a timestamp or + * `format` was an invalid value * @group datetime_funcs * @since 2.3.0 */ @@ -5434,19 +5604,21 @@ object functions { /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders - * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield - * '2017-07-14 03:40:00.0'. - * - * @param ts A date, timestamp or string. If a string, the data must be in a format that can be - * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param tz A string detailing the time zone ID that the input should be adjusted to. It should - * be in the format of either region-based zone IDs or zone offsets. Region IDs must - * have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in - * the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are - * supported as aliases of '+00:00'. Other short names are not recommended to use - * because they can be ambiguous. - * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or - * `tz` was an invalid value + * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield '2017-07-14 + * 03:40:00.0'. + * + * @param ts + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param tz + * A string detailing the time zone ID that the input should be adjusted to. It should be in + * the format of either region-based zone IDs or zone offsets. Region IDs must have the form + * 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in the format + * '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases + * of '+00:00'. Other short names are not recommended to use because they can be ambiguous. + * @return + * A timestamp, or null if `ts` was a string that could not be cast to a timestamp or `tz` was + * an invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -5454,8 +5626,8 @@ object functions { /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders - * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield - * '2017-07-14 03:40:00.0'. + * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield '2017-07-14 + * 03:40:00.0'. * @group datetime_funcs * @since 2.4.0 */ @@ -5467,16 +5639,18 @@ object functions { * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield * '2017-07-14 01:40:00.0'. * - * @param ts A date, timestamp or string. If a string, the data must be in a format that can be - * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` - * @param tz A string detailing the time zone ID that the input should be adjusted to. It should - * be in the format of either region-based zone IDs or zone offsets. Region IDs must - * have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in - * the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are - * supported as aliases of '+00:00'. Other short names are not recommended to use - * because they can be ambiguous. 
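The time-zone rendering described above, sketched with the documented 'GMT+1' example (reuses the assumed `spark` session):

{{{
Seq("2017-07-14 02:40:00").toDF("t").select(
  from_utc_timestamp(col("t"), "GMT+1").as("as_local"),  // -> 2017-07-14 03:40:00, per the example above
  to_utc_timestamp(col("t"), "GMT+1").as("as_utc")       // -> 2017-07-14 01:40:00
).show(false)
}}}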
- * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or - * `tz` was an invalid value + * @param ts + * A date, timestamp or string. If a string, the data must be in a format that can be cast to + * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param tz + * A string detailing the time zone ID that the input should be adjusted to. It should be in + * the format of either region-based zone IDs or zone offsets. Region IDs must have the form + * 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in the format + * '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases + * of '+00:00'. Other short names are not recommended to use because they can be ambiguous. + * @return + * A timestamp, or null if `ts` was a string that could not be cast to a timestamp or `tz` was + * an invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -5495,8 +5669,8 @@ object functions { * Bucketize rows into one or more time windows given a timestamp specifying column. Window * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in - * the order of months are not supported. The following example takes the average stock price for - * a one minute window every 10 seconds starting 5 seconds after the hour: + * the order of months are not supported. The following example takes the average stock price + * for a one minute window every 10 seconds starting 5 seconds after the hour: * * {{{ * val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType @@ -5515,23 +5689,23 @@ object functions { * For a streaming query, you may use the function `current_timestamp` to generate windows on * processing time. * - * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time column must be of TimestampType or TimestampNTZType. - * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, - * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for - * valid duration identifiers. Note that the duration is a fixed length of - * time, and does not vary over time according to a calendar. For example, - * `1 day` always means 86,400,000 milliseconds, not a calendar day. - * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`. - * A new window will be generated every `slideDuration`. Must be less than - * or equal to the `windowDuration`. Check - * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration - * identifiers. This duration is likewise absolute, and does not vary - * according to a calendar. - * @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start - * window intervals. For example, in order to have hourly tumbling windows that - * start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide - * `startTime` as `15 minutes`. + * @param timeColumn + * The column or the expression to use as the timestamp for windowing by time. The time column + * must be of TimestampType or TimestampNTZType. + * @param windowDuration + * A string specifying the width of the window, e.g. `10 minutes`, `1 second`. Check + * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration identifiers. 
Note that + * the duration is a fixed length of time, and does not vary over time according to a + * calendar. For example, `1 day` always means 86,400,000 milliseconds, not a calendar day. + * @param slideDuration + * A string specifying the sliding interval of the window, e.g. `1 minute`. A new window will + * be generated every `slideDuration`. Must be less than or equal to the `windowDuration`. + * Check `org.apache.spark.unsafe.types.CalendarInterval` for valid duration identifiers. This + * duration is likewise absolute, and does not vary according to a calendar. + * @param startTime + * The offset with respect to 1970-01-01 00:00:00 UTC with which to start window intervals. + * For example, in order to have hourly tumbling windows that start 15 minutes past the hour, + * e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`. * * @group datetime_funcs * @since 2.0.0 @@ -5547,8 +5721,9 @@ object functions { * Bucketize rows into one or more time windows given a timestamp specifying column. Window * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in - * the order of months are not supported. The windows start beginning at 1970-01-01 00:00:00 UTC. - * The following example takes the average stock price for a one minute window every 10 seconds: + * the order of months are not supported. The windows start beginning at 1970-01-01 00:00:00 + * UTC. The following example takes the average stock price for a one minute window every 10 + * seconds: * * {{{ * val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType @@ -5567,19 +5742,19 @@ object functions { * For a streaming query, you may use the function `current_timestamp` to generate windows on * processing time. * - * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time column must be of TimestampType or TimestampNTZType. - * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, - * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for - * valid duration identifiers. Note that the duration is a fixed length of - * time, and does not vary over time according to a calendar. For example, - * `1 day` always means 86,400,000 milliseconds, not a calendar day. - * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`. - * A new window will be generated every `slideDuration`. Must be less than - * or equal to the `windowDuration`. Check - * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration - * identifiers. This duration is likewise absolute, and does not vary - * according to a calendar. + * @param timeColumn + * The column or the expression to use as the timestamp for windowing by time. The time column + * must be of TimestampType or TimestampNTZType. + * @param windowDuration + * A string specifying the width of the window, e.g. `10 minutes`, `1 second`. Check + * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration identifiers. Note that + * the duration is a fixed length of time, and does not vary over time according to a + * calendar. For example, `1 day` always means 86,400,000 milliseconds, not a calendar day. + * @param slideDuration + * A string specifying the sliding interval of the window, e.g. `1 minute`. A new window will + * be generated every `slideDuration`. 
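A runnable batch-mode sketch of the sliding `window` aggregation described above; the `events` data and column names are made up for illustration, and the snippet reuses the `spark` session and imports from the first sketch:

{{{
import java.sql.Timestamp

val events = Seq(
  (Timestamp.valueOf("2024-01-01 00:00:07"), "A", 10.0),
  (Timestamp.valueOf("2024-01-01 00:00:57"), "A", 20.0)
).toDF("timestamp", "stockId", "price")

events
  .groupBy(window(col("timestamp"), "1 minute", "10 seconds", "5 seconds"), col("stockId"))
  .agg(avg(col("price")).as("avg_price"))       // one-minute windows, sliding every 10s, offset by 5s
  .orderBy(col("window.start"))
  .show(false)
}}}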
Must be less than or equal to the `windowDuration`. + * Check `org.apache.spark.unsafe.types.CalendarInterval` for valid duration identifiers. This + * duration is likewise absolute, and does not vary according to a calendar. * * @group datetime_funcs * @since 2.0.0 @@ -5589,11 +5764,11 @@ object functions { } /** - * Generates tumbling time windows given a timestamp specifying column. Window - * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window - * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in - * the order of months are not supported. The windows start beginning at 1970-01-01 00:00:00 UTC. - * The following example takes the average stock price for a one minute tumbling window: + * Generates tumbling time windows given a timestamp specifying column. Window starts are + * inclusive but the window ends are exclusive, e.g. 12:05 will be in the window [12:05,12:10) + * but not in [12:00,12:05). Windows can support microsecond precision. Windows in the order of + * months are not supported. The windows start beginning at 1970-01-01 00:00:00 UTC. The + * following example takes the average stock price for a one minute tumbling window: * * {{{ * val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType @@ -5612,11 +5787,12 @@ object functions { * For a streaming query, you may use the function `current_timestamp` to generate windows on * processing time. * - * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time column must be of TimestampType or TimestampNTZType. - * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, - * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for - * valid duration identifiers. + * @param timeColumn + * The column or the expression to use as the timestamp for windowing by time. The time column + * must be of TimestampType or TimestampNTZType. + * @param windowDuration + * A string specifying the width of the window, e.g. `10 minutes`, `1 second`. Check + * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration identifiers. * * @group datetime_funcs * @since 2.0.0 @@ -5632,8 +5808,9 @@ object functions { * inclusive and end is exclusive. Since event time can support microsecond precision, * window_time(window) = window.end - 1 microsecond. * - * @param windowColumn The window column (typically produced by window aggregation) of type - * StructType { start: Timestamp, end: Timestamp } + * @param windowColumn + * The window column (typically produced by window aggregation) of type StructType { start: + * Timestamp, end: Timestamp } * * @group datetime_funcs * @since 3.4.0 @@ -5644,10 +5821,9 @@ object functions { * Generates session window given a timestamp specifying column. * * Session window is one of dynamic windows, which means the length of window is varying - * according to the given inputs. The length of session window is defined as "the timestamp - * of latest input of the session + gap duration", so when the new inputs are bound to the - * current session window, the end time of session window can be expanded according to the new - * inputs. + * according to the given inputs. The length of session window is defined as "the timestamp of + * latest input of the session + gap duration", so when the new inputs are bound to the current + * session window, the end time of session window can be expanded according to the new inputs. 
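Editorial sketch for the static-gap `session_window` described above, under the same assumptions as the earlier sketch (active `spark`, implicits and `functions._` imported, invented data):
{{{
// Hypothetical user-activity events; a session closes after 10 minutes of inactivity.
val events = Seq(
  ("2024-01-01 09:00:00", "alice"),
  ("2024-01-01 09:04:00", "alice"),
  ("2024-01-01 09:30:00", "alice")
).toDF("ts", "user")
  .withColumn("ts", to_timestamp($"ts"))

events
  .groupBy(session_window($"ts", "10 minutes"), $"user")
  .agg(count("*").as("eventsPerSession"))
}}}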
* * Windows can support microsecond precision. gapDuration in the order of months are not * supported. @@ -5655,11 +5831,12 @@ object functions { * For a streaming query, you may use the function `current_timestamp` to generate windows on * processing time. * - * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time column must be of TimestampType or TimestampNTZType. - * @param gapDuration A string specifying the timeout of the session, e.g. `10 minutes`, - * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for - * valid duration identifiers. + * @param timeColumn + * The column or the expression to use as the timestamp for windowing by time. The time column + * must be of TimestampType or TimestampNTZType. + * @param gapDuration + * A string specifying the timeout of the session, e.g. `10 minutes`, `1 second`. Check + * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration identifiers. * * @group datetime_funcs * @since 3.2.0 @@ -5671,17 +5848,17 @@ object functions { * Generates session window given a timestamp specifying column. * * Session window is one of dynamic windows, which means the length of window is varying - * according to the given inputs. For static gap duration, the length of session window - * is defined as "the timestamp of latest input of the session + gap duration", so when - * the new inputs are bound to the current session window, the end time of session window - * can be expanded according to the new inputs. - * - * Besides a static gap duration value, users can also provide an expression to specify - * gap duration dynamically based on the input row. With dynamic gap duration, the closing - * of a session window does not depend on the latest input anymore. A session window's range - * is the union of all events' ranges which are determined by event start time and evaluated - * gap duration during the query execution. Note that the rows with negative or zero gap - * duration will be filtered out from the aggregation. + * according to the given inputs. For static gap duration, the length of session window is + * defined as "the timestamp of latest input of the session + gap duration", so when the new + * inputs are bound to the current session window, the end time of session window can be + * expanded according to the new inputs. + * + * Besides a static gap duration value, users can also provide an expression to specify gap + * duration dynamically based on the input row. With dynamic gap duration, the closing of a + * session window does not depend on the latest input anymore. A session window's range is the + * union of all events' ranges which are determined by event start time and evaluated gap + * duration during the query execution. Note that the rows with negative or zero gap duration + * will be filtered out from the aggregation. * * Windows can support microsecond precision. gapDuration in the order of months are not * supported. @@ -5689,11 +5866,13 @@ object functions { * For a streaming query, you may use the function `current_timestamp` to generate windows on * processing time. * - * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time column must be of TimestampType or TimestampNTZType. - * @param gapDuration A column specifying the timeout of the session. It could be static value, - * e.g. `10 minutes`, `1 second`, or an expression/UDF that specifies gap - * duration dynamically based on the input row. 
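A rough illustration of the dynamic gap duration variant of `session_window` documented here; the device-based rule is invented, and the string durations inside `when`/`otherwise` follow the pattern used in Spark's streaming guide:
{{{
val events = Seq(
  ("2024-01-01 09:00:00", "alice", "mobile"),
  ("2024-01-01 09:02:00", "alice", "desktop")
).toDF("ts", "user", "device")
  .withColumn("ts", to_timestamp($"ts"))

// Mobile sessions time out faster than desktop sessions.
val gap = when($"device" === "mobile", "5 minutes").otherwise("30 minutes")

events
  .groupBy(session_window($"ts", gap), $"user")
  .count()
}}}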
+ * @param timeColumn + * The column or the expression to use as the timestamp for windowing by time. The time column + * must be of TimestampType or TimestampNTZType. + * @param gapDuration + * A column specifying the timeout of the session. It could be static value, e.g. `10 + * minutes`, `1 second`, or an expression/UDF that specifies gap duration dynamically based on + * the input row. * * @group datetime_funcs * @since 3.2.0 @@ -5702,8 +5881,7 @@ object functions { Column.fn("session_window", timeColumn, gapDuration) /** - * Converts the number of seconds from the Unix epoch (1970-01-01T00:00:00Z) - * to a timestamp. + * Converts the number of seconds from the Unix epoch (1970-01-01T00:00:00Z) to a timestamp. * @group datetime_funcs * @since 3.1.0 */ @@ -5726,8 +5904,8 @@ object functions { def timestamp_micros(e: Column): Column = Column.fn("timestamp_micros", e) /** - * Gets the difference between the timestamps in the specified units by truncating - * the fraction part. + * Gets the difference between the timestamps in the specified units by truncating the fraction + * part. * * @group datetime_funcs * @since 4.0.0 @@ -5745,8 +5923,8 @@ object functions { Column.internalFn("timestampadd", lit(unit), quantity, ts) /** - * Parses the `timestamp` expression with the `format` expression - * to a timestamp without time zone. Returns null with invalid input. + * Parses the `timestamp` expression with the `format` expression to a timestamp without time + * zone. Returns null with invalid input. * * @group datetime_funcs * @since 3.5.0 @@ -5765,8 +5943,8 @@ object functions { Column.fn("to_timestamp_ltz", timestamp) /** - * Parses the `timestamp_str` expression with the `format` expression - * to a timestamp without time zone. Returns null with invalid input. + * Parses the `timestamp_str` expression with the `format` expression to a timestamp without + * time zone. Returns null with invalid input. * * @group datetime_funcs * @since 3.5.0 @@ -5855,9 +6033,12 @@ object functions { * Returns an array containing all the elements in `x` from index `start` (or starting from the * end if `start` is negative) with the specified `length`. * - * @param x the array column to be sliced - * @param start the starting index - * @param length the length of the slice + * @param x + * the array column to be sliced + * @param start + * the starting index + * @param length + * the length of the slice * * @group array_funcs * @since 2.4.0 @@ -5869,9 +6050,12 @@ object functions { * Returns an array containing all the elements in `x` from index `start` (or starting from the * end if `start` is negative) with the specified `length`. * - * @param x the array column to be sliced - * @param start the starting index - * @param length the length of the slice + * @param x + * the array column to be sliced + * @param start + * the starting index + * @param length + * the length of the slice * * @group array_funcs * @since 3.1.0 @@ -5897,10 +6081,11 @@ object functions { Column.fn("array_join", column, lit(delimiter)) /** - * Concatenates multiple input columns together into a single column. - * The function works with strings, binary and compatible array columns. + * Concatenates multiple input columns together into a single column. The function works with + * strings, binary and compatible array columns. * - * @note Returns null if any of the input columns are null. + * @note + * Returns null if any of the input columns are null. 
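Quick editorial sketch for the array helpers reflowed in this hunk (`slice`, `array_join`, `concat`), using invented columns:
{{{
val df = Seq((Seq(1, 2, 3, 4, 5), Seq("a", "b", "c"))).toDF("nums", "letters")

df.select(
  slice($"nums", 2, 3).as("middle"),        // [2, 3, 4] (1-based start, length 3)
  array_join($"letters", "-").as("joined"), // "a-b-c"
  concat($"nums", $"nums").as("doubled")    // [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
)
}}}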
* * @group collection_funcs * @since 1.5.0 @@ -5909,11 +6094,12 @@ object functions { def concat(exprs: Column*): Column = Column.fn("concat", exprs: _*) /** - * Locates the position of the first occurrence of the value in the given array as long. - * Returns null if either of the arguments are null. + * Locates the position of the first occurrence of the value in the given array as long. Returns + * null if either of the arguments are null. * - * @note The position is not zero based, but 1 based index. Returns 0 if value - * could not be found in array. + * @note + * The position is not zero based, but 1 based index. Returns 0 if value could not be found in + * array. * * @group array_funcs * @since 2.4.0 @@ -5922,8 +6108,8 @@ object functions { Column.fn("array_position", column, lit(value)) /** - * Returns element of array at given index in value if column is array. Returns value for - * the given key in value if column is map. + * Returns element of array at given index in value if column is array. Returns value for the + * given key in value if column is map. * * @group collection_funcs * @since 2.4.0 @@ -5945,8 +6131,8 @@ object functions { Column.fn("try_element_at", column, value) /** - * Returns element of array at given (0-based) index. If the index points - * outside of the array boundaries, then this function returns NULL. + * Returns element of array at given (0-based) index. If the index points outside of the array + * boundaries, then this function returns NULL. * * @group array_funcs * @since 3.4.0 @@ -5955,8 +6141,8 @@ object functions { /** * Sorts the input array in ascending order. The elements of the input array must be orderable. - * NaN is greater than any non-NaN elements for double/float type. - * Null elements will be placed at the end of the returned array. + * NaN is greater than any non-NaN elements for double/float type. Null elements will be placed + * at the end of the returned array. * * @group collection_funcs * @since 2.4.0 @@ -6010,8 +6196,8 @@ object functions { def array_distinct(e: Column): Column = Column.fn("array_distinct", e) /** - * Returns an array of the elements in the intersection of the given two arrays, - * without duplicates. + * Returns an array of the elements in the intersection of the given two arrays, without + * duplicates. * * @group array_funcs * @since 2.4.0 @@ -6038,8 +6224,8 @@ object functions { Column.fn("array_union", col1, col2) /** - * Returns an array of the elements in the first array but not in the second array, - * without duplicates. The order of elements in the result is not determined + * Returns an array of the elements in the first array but not in the second array, without + * duplicates. The order of elements in the result is not determined * * @group array_funcs * @since 2.4.0 @@ -6047,7 +6233,6 @@ object functions { def array_except(col1: Column, col2: Column): Column = Column.fn("array_except", col1, col2) - private def createLambda(f: Column => Column) = { val x = internal.UnresolvedNamedLambdaVariable("x") val function = f(Column(x)).node @@ -6070,14 +6255,16 @@ object functions { } /** - * Returns an array of elements after applying a transformation to each element - * in the input array. + * Returns an array of elements after applying a transformation to each element in the input + * array. 
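The indexing functions touched just above differ in base index and null behaviour; a small hypothetical comparison (same assumptions as before):
{{{
val df = Seq((Seq("a", "b", "c"), Seq("b", "d"))).toDF("xs", "ys")

df.select(
  array_position($"xs", "b").as("posOfB"),   // 2 (1-based; 0 when the value is absent)
  element_at($"xs", 1).as("first"),          // "a" (1-based)
  get($"xs", lit(0)).as("alsoFirst"),        // "a" (0-based; null when out of range)
  array_except($"xs", $"ys").as("onlyInXs")  // ["a", "c"]
)
}}}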
* {{{ * df.select(transform(col("i"), x => x + 1)) * }}} * - * @param column the input array column - * @param f col => transformed_col, the lambda function to transform the input column + * @param column + * the input array column + * @param f + * col => transformed_col, the lambda function to transform the input column * * @group collection_funcs * @since 3.0.0 @@ -6086,15 +6273,17 @@ object functions { Column.fn("transform", column, createLambda(f)) /** - * Returns an array of elements after applying a transformation to each element - * in the input array. + * Returns an array of elements after applying a transformation to each element in the input + * array. * {{{ * df.select(transform(col("i"), (x, i) => x + i)) * }}} * - * @param column the input array column - * @param f (col, index) => transformed_col, the lambda function to transform the input - * column given the index. Indices start at 0. + * @param column + * the input array column + * @param f + * (col, index) => transformed_col, the lambda function to transform the input column given + * the index. Indices start at 0. * * @group collection_funcs * @since 3.0.0 @@ -6108,8 +6297,10 @@ object functions { * df.select(exists(col("i"), _ % 2 === 0)) * }}} * - * @param column the input array column - * @param f col => predicate, the Boolean predicate to check the input column + * @param column + * the input array column + * @param f + * col => predicate, the Boolean predicate to check the input column * * @group collection_funcs * @since 3.0.0 @@ -6123,8 +6314,10 @@ object functions { * df.select(forall(col("i"), x => x % 2 === 0)) * }}} * - * @param column the input array column - * @param f col => predicate, the Boolean predicate to check the input column + * @param column + * the input array column + * @param f + * col => predicate, the Boolean predicate to check the input column * * @group collection_funcs * @since 3.0.0 @@ -6138,8 +6331,10 @@ object functions { * df.select(filter(col("s"), x => x % 2 === 0)) * }}} * - * @param column the input array column - * @param f col => predicate, the Boolean predicate to filter the input column + * @param column + * the input array column + * @param f + * col => predicate, the Boolean predicate to filter the input column * * @group collection_funcs * @since 3.0.0 @@ -6153,9 +6348,11 @@ object functions { * df.select(filter(col("s"), (x, i) => i % 2 === 0)) * }}} * - * @param column the input array column - * @param f (col, index) => predicate, the Boolean predicate to filter the input column - * given the index. Indices start at 0. + * @param column + * the input array column + * @param f + * (col, index) => predicate, the Boolean predicate to filter the input column given the + * index. Indices start at 0. * * @group collection_funcs * @since 3.0.0 @@ -6164,19 +6361,23 @@ object functions { Column.fn("filter", column, createLambda(f)) /** - * Applies a binary operator to an initial state and all elements in the array, - * and reduces this to a single state. The final state is converted into the final result - * by applying a finish function. + * Applies a binary operator to an initial state and all elements in the array, and reduces this + * to a single state. The final state is converted into the final result by applying a finish + * function. 
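For the higher-order predicates in this hunk (`exists`, `forall`, `filter`, including the indexed variant), a hedged sketch with invented data:
{{{
val df = Seq(Tuple1(Seq(1, 2, 3, 4))).toDF("i")

df.select(
  exists($"i", x => x % 2 === 0).as("hasEven"),             // true
  forall($"i", x => x > 0).as("allPositive"),               // true
  filter($"i", x => x % 2 === 0).as("evens"),               // [2, 4]
  filter($"i", (x, idx) => idx % 2 === 0).as("atEvenIndex") // [1, 3], indices start at 0
)
}}}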
* {{{ * df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x, _ * 10)) * }}} * - * @param expr the input array column - * @param initialValue the initial value - * @param merge (combined_value, input_value) => combined_value, the merge function to merge - * an input value to the combined_value - * @param finish combined_value => final_value, the lambda function to convert the combined value - * of all inputs to final result + * @param expr + * the input array column + * @param initialValue + * the initial value + * @param merge + * (combined_value, input_value) => combined_value, the merge function to merge an input value + * to the combined_value + * @param finish + * combined_value => final_value, the lambda function to convert the combined value of all + * inputs to final result * * @group collection_funcs * @since 3.0.0 @@ -6189,16 +6390,19 @@ object functions { Column.fn("aggregate", expr, initialValue, createLambda(merge), createLambda(finish)) /** - * Applies a binary operator to an initial state and all elements in the array, - * and reduces this to a single state. + * Applies a binary operator to an initial state and all elements in the array, and reduces this + * to a single state. * {{{ * df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)) * }}} * - * @param expr the input array column - * @param initialValue the initial value - * @param merge (combined_value, input_value) => combined_value, the merge function to merge - * an input value to the combined_value + * @param expr + * the input array column + * @param initialValue + * the initial value + * @param merge + * (combined_value, input_value) => combined_value, the merge function to merge an input value + * to the combined_value * @group collection_funcs * @since 3.0.0 */ @@ -6206,19 +6410,23 @@ object functions { aggregate(expr, initialValue, merge, c => c) /** - * Applies a binary operator to an initial state and all elements in the array, - * and reduces this to a single state. The final state is converted into the final result - * by applying a finish function. + * Applies a binary operator to an initial state and all elements in the array, and reduces this + * to a single state. The final state is converted into the final result by applying a finish + * function. * {{{ * df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x, _ * 10)) * }}} * - * @param expr the input array column - * @param initialValue the initial value - * @param merge (combined_value, input_value) => combined_value, the merge function to merge - * an input value to the combined_value - * @param finish combined_value => final_value, the lambda function to convert the combined value - * of all inputs to final result + * @param expr + * the input array column + * @param initialValue + * the initial value + * @param merge + * (combined_value, input_value) => combined_value, the merge function to merge an input value + * to the combined_value + * @param finish + * combined_value => final_value, the lambda function to convert the combined value of all + * inputs to final result * * @group collection_funcs * @since 3.5.0 @@ -6231,16 +6439,19 @@ object functions { Column.fn("reduce", expr, initialValue, createLambda(merge), createLambda(finish)) /** - * Applies a binary operator to an initial state and all elements in the array, - * and reduces this to a single state. + * Applies a binary operator to an initial state and all elements in the array, and reduces this + * to a single state. 
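The merge/finish contract of `aggregate` (and its `reduce` alias) reads more clearly with a worked value; a minimal sketch, assuming the same imports as above:
{{{
val df = Seq(Tuple1(Seq(1, 2, 3))).toDF("i")

// Sum the array elements, then scale the final accumulator by 10 in the finish step.
df.select(
  aggregate($"i", lit(0), (acc, x) => acc + x, acc => acc * 10).as("sumTimes10") // 60
)
}}}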
* {{{ * df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)) * }}} * - * @param expr the input array column - * @param initialValue the initial value - * @param merge (combined_value, input_value) => combined_value, the merge function to merge - * an input value to the combined_value + * @param expr + * the input array column + * @param initialValue + * the initial value + * @param merge + * (combined_value, input_value) => combined_value, the merge function to merge an input value + * to the combined_value * @group collection_funcs * @since 3.5.0 */ @@ -6248,16 +6459,19 @@ object functions { reduce(expr, initialValue, merge, c => c) /** - * Merge two given arrays, element-wise, into a single array using a function. - * If one array is shorter, nulls are appended at the end to match the length of the longer - * array, before applying the function. + * Merge two given arrays, element-wise, into a single array using a function. If one array is + * shorter, nulls are appended at the end to match the length of the longer array, before + * applying the function. * {{{ * df.select(zip_with(df1("val1"), df1("val2"), (x, y) => x + y)) * }}} * - * @param left the left input array column - * @param right the right input array column - * @param f (lCol, rCol) => col, the lambda function to merge two input columns into one column + * @param left + * the left input array column + * @param right + * the right input array column + * @param f + * (lCol, rCol) => col, the lambda function to merge two input columns into one column * * @group collection_funcs * @since 3.0.0 @@ -6266,14 +6480,16 @@ object functions { Column.fn("zip_with", left, right, createLambda(f)) /** - * Applies a function to every key-value pair in a map and returns - * a map with the results of those applications as the new keys for the pairs. + * Applies a function to every key-value pair in a map and returns a map with the results of + * those applications as the new keys for the pairs. * {{{ * df.select(transform_keys(col("i"), (k, v) => k + v)) * }}} * - * @param expr the input map column - * @param f (key, value) => new_key, the lambda function to transform the key of input map column + * @param expr + * the input map column + * @param f + * (key, value) => new_key, the lambda function to transform the key of input map column * * @group collection_funcs * @since 3.0.0 @@ -6282,15 +6498,16 @@ object functions { Column.fn("transform_keys", expr, createLambda(f)) /** - * Applies a function to every key-value pair in a map and returns - * a map with the results of those applications as the new values for the pairs. + * Applies a function to every key-value pair in a map and returns a map with the results of + * those applications as the new values for the pairs. 
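A short editorial example for `zip_with`, `transform_keys`, and `transform_values` as documented in this hunk, with invented array and map columns:
{{{
val df = Seq((Seq(1, 2, 3), Seq(10, 20, 30), Map("a" -> 1, "b" -> 2)))
  .toDF("xs", "ys", "m")

df.select(
  zip_with($"xs", $"ys", (x, y) => x + y).as("sums"),         // [11, 22, 33]
  transform_keys($"m", (k, v) => upper(k)).as("upperKeys"),   // {"A": 1, "B": 2}
  transform_values($"m", (k, v) => v * 10).as("scaledValues") // {"a": 10, "b": 20}
)
}}}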
* {{{ * df.select(transform_values(col("i"), (k, v) => k + v)) * }}} * - * @param expr the input map column - * @param f (key, value) => new_value, the lambda function to transform the value of input map - * column + * @param expr + * the input map column + * @param f + * (key, value) => new_value, the lambda function to transform the value of input map column * * @group collection_funcs * @since 3.0.0 @@ -6304,8 +6521,10 @@ object functions { * df.select(map_filter(col("m"), (k, v) => k * 10 === v)) * }}} * - * @param expr the input map column - * @param f (key, value) => predicate, the Boolean predicate to filter the input map column + * @param expr + * the input map column + * @param f + * (key, value) => predicate, the Boolean predicate to filter the input map column * * @group collection_funcs * @since 3.0.0 @@ -6319,9 +6538,12 @@ object functions { * df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => k === v1 + v2)) * }}} * - * @param left the left input map column - * @param right the right input map column - * @param f (key, value1, value2) => new_value, the lambda function to merge the map values + * @param left + * the left input map column + * @param right + * the right input map column + * @param f + * (key, value1, value2) => new_value, the lambda function to merge the map values * * @group collection_funcs * @since 3.0.0 @@ -6330,9 +6552,9 @@ object functions { Column.fn("map_zip_with", left, right, createLambda(f)) /** - * Creates a new row for each element in the given array or map column. - * Uses the default column name `col` for elements in the array and - * `key` and `value` for elements in the map unless specified otherwise. + * Creates a new row for each element in the given array or map column. Uses the default column + * name `col` for elements in the array and `key` and `value` for elements in the map unless + * specified otherwise. * * @group generator_funcs * @since 1.3.0 @@ -6340,10 +6562,9 @@ object functions { def explode(e: Column): Column = Column.fn("explode", e) /** - * Creates a new row for each element in the given array or map column. - * Uses the default column name `col` for elements in the array and - * `key` and `value` for elements in the map unless specified otherwise. - * Unlike explode, if the array/map is null or empty then null is produced. + * Creates a new row for each element in the given array or map column. Uses the default column + * name `col` for elements in the array and `key` and `value` for elements in the map unless + * specified otherwise. Unlike explode, if the array/map is null or empty then null is produced. * * @group generator_funcs * @since 2.2.0 @@ -6351,9 +6572,9 @@ object functions { def explode_outer(e: Column): Column = Column.fn("explode_outer", e) /** - * Creates a new row for each element with position in the given array or map column. - * Uses the default column name `pos` for position, and `col` for elements in the array - * and `key` and `value` for elements in the map unless specified otherwise. + * Creates a new row for each element with position in the given array or map column. Uses the + * default column name `pos` for position, and `col` for elements in the array and `key` and + * `value` for elements in the map unless specified otherwise. * * @group generator_funcs * @since 2.1.0 @@ -6361,17 +6582,17 @@ object functions { def posexplode(e: Column): Column = Column.fn("posexplode", e) /** - * Creates a new row for each element with position in the given array or map column. 
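Sketch for the map functions and generators above (`map_filter`, `map_zip_with`, `explode`); data and names are hypothetical:
{{{
val maps = Seq((Map("a" -> 1, "b" -> -1), Map("a" -> 10, "b" -> 20))).toDF("m1", "m2")

maps.select(
  map_filter($"m1", (k, v) => v > 0).as("positives"),             // {"a": 1}
  map_zip_with($"m1", $"m2", (k, v1, v2) => v1 + v2).as("merged") // {"a": 11, "b": 19}
)

// explode emits one row per element, using the default column name `col`.
Seq(Tuple1(Seq("x", "y"))).toDF("arr").select(explode($"arr"))
}}}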
- * Uses the default column name `pos` for position, and `col` for elements in the array - * and `key` and `value` for elements in the map unless specified otherwise. - * Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced. + * Creates a new row for each element with position in the given array or map column. Uses the + * default column name `pos` for position, and `col` for elements in the array and `key` and + * `value` for elements in the map unless specified otherwise. Unlike posexplode, if the + * array/map is null or empty then the row (null, null) is produced. * * @group generator_funcs * @since 2.2.0 */ def posexplode_outer(e: Column): Column = Column.fn("posexplode_outer", e) - /** + /** * Creates a new row for each element in the given array of structs. * * @group generator_funcs @@ -6380,8 +6601,8 @@ object functions { def inline(e: Column): Column = Column.fn("inline", e) /** - * Creates a new row for each element in the given array of structs. - * Unlike inline, if the array is null or empty then null is produced for each nested column. + * Creates a new row for each element in the given array of structs. Unlike inline, if the array + * is null or empty then null is produced for each nested column. * * @group generator_funcs * @since 3.4.0 @@ -6415,14 +6636,15 @@ object functions { * (Scala-specific) Parses a column containing a JSON string into a `StructType` with the * specified schema. Returns `null`, in the case of an unparseable string. * - * @param e a string column containing JSON data. - * @param schema the schema to use when parsing the json string - * @param options options to control how the json is parsed. Accepts the same options as the - * json data source. - * See - * - * Data Source Option in the version you use. + * @param e + * a string column containing JSON data. + * @param schema + * the schema to use when parsing the json string + * @param options + * options to control how the json is parsed. Accepts the same options as the json data + * source. See Data + * Source Option in the version you use. * * @group json_funcs * @since 2.1.0 @@ -6434,17 +6656,18 @@ object functions { // scalastyle:off line.size.limit /** * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` - * as keys type, `StructType` or `ArrayType` with the specified schema. - * Returns `null`, in the case of an unparseable string. - * - * @param e a string column containing JSON data. - * @param schema the schema to use when parsing the json string - * @param options options to control how the json is parsed. accepts the same options and the - * json data source. - * See - * - * Data Source Option in the version you use. + * as keys type, `StructType` or `ArrayType` with the specified schema. Returns `null`, in the + * case of an unparseable string. + * + * @param e + * a string column containing JSON data. + * @param schema + * the schema to use when parsing the json string + * @param options + * options to control how the json is parsed. accepts the same options and the json data + * source. See Data + * Source Option in the version you use. * * @group json_funcs * @since 2.2.0 @@ -6459,14 +6682,15 @@ object functions { * (Java-specific) Parses a column containing a JSON string into a `StructType` with the * specified schema. Returns `null`, in the case of an unparseable string. * - * @param e a string column containing JSON data. 
- * @param schema the schema to use when parsing the json string - * @param options options to control how the json is parsed. accepts the same options and the - * json data source. - * See - * - * Data Source Option in the version you use. + * @param e + * a string column containing JSON data. + * @param schema + * the schema to use when parsing the json string + * @param options + * options to control how the json is parsed. accepts the same options and the json data + * source. See Data + * Source Option in the version you use. * * @group json_funcs * @since 2.1.0 @@ -6478,17 +6702,18 @@ object functions { // scalastyle:off line.size.limit /** * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` - * as keys type, `StructType` or `ArrayType` with the specified schema. - * Returns `null`, in the case of an unparseable string. - * - * @param e a string column containing JSON data. - * @param schema the schema to use when parsing the json string - * @param options options to control how the json is parsed. accepts the same options and the - * json data source. - * See - * - * Data Source Option in the version you use. + * as keys type, `StructType` or `ArrayType` with the specified schema. Returns `null`, in the + * case of an unparseable string. + * + * @param e + * a string column containing JSON data. + * @param schema + * the schema to use when parsing the json string + * @param options + * options to control how the json is parsed. accepts the same options and the json data + * source. See Data + * Source Option in the version you use. * * @group json_funcs * @since 2.2.0 @@ -6502,8 +6727,10 @@ object functions { * Parses a column containing a JSON string into a `StructType` with the specified schema. * Returns `null`, in the case of an unparseable string. * - * @param e a string column containing JSON data. - * @param schema the schema to use when parsing the json string + * @param e + * a string column containing JSON data. + * @param schema + * the schema to use when parsing the json string * * @group json_funcs * @since 2.1.0 @@ -6513,11 +6740,13 @@ object functions { /** * Parses a column containing a JSON string into a `MapType` with `StringType` as keys type, - * `StructType` or `ArrayType` with the specified schema. - * Returns `null`, in the case of an unparseable string. + * `StructType` or `ArrayType` with the specified schema. Returns `null`, in the case of an + * unparseable string. * - * @param e a string column containing JSON data. - * @param schema the schema to use when parsing the json string + * @param e + * a string column containing JSON data. + * @param schema + * the schema to use when parsing the json string * * @group json_funcs * @since 2.2.0 @@ -6528,17 +6757,18 @@ object functions { // scalastyle:off line.size.limit /** * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` - * as keys type, `StructType` or `ArrayType` with the specified schema. - * Returns `null`, in the case of an unparseable string. - * - * @param e a string column containing JSON data. - * @param schema the schema as a DDL-formatted string. - * @param options options to control how the json is parsed. accepts the same options and the - * json data source. - * See - * - * Data Source Option in the version you use. + * as keys type, `StructType` or `ArrayType` with the specified schema. Returns `null`, in the + * case of an unparseable string. + * + * @param e + * a string column containing JSON data. 
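The `from_json` overloads being reflowed here all share the same schema-plus-options shape; a compact sketch (the `mode` option is one of the standard JSON data source options):
{{{
import org.apache.spark.sql.types._

val df = Seq("""{"id": 1, "name": "alice"}""").toDF("json")
val schema = new StructType().add("id", LongType).add("name", StringType)

df.select(from_json($"json", schema, Map("mode" -> "PERMISSIVE")).as("parsed"))
  .select($"parsed.id", $"parsed.name")
}}}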
+ * @param schema + * the schema as a DDL-formatted string. + * @param options + * options to control how the json is parsed. accepts the same options and the json data + * source. See Data + * Source Option in the version you use. * * @group json_funcs * @since 2.1.0 @@ -6551,17 +6781,18 @@ object functions { // scalastyle:off line.size.limit /** * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` - * as keys type, `StructType` or `ArrayType` with the specified schema. - * Returns `null`, in the case of an unparseable string. - * - * @param e a string column containing JSON data. - * @param schema the schema as a DDL-formatted string. - * @param options options to control how the json is parsed. accepts the same options and the - * json data source. - * See - * - * Data Source Option in the version you use. + * as keys type, `StructType` or `ArrayType` with the specified schema. Returns `null`, in the + * case of an unparseable string. + * + * @param e + * a string column containing JSON data. + * @param schema + * the schema as a DDL-formatted string. + * @param options + * options to control how the json is parsed. accepts the same options and the json data + * source. See Data + * Source Option in the version you use. * * @group json_funcs * @since 2.3.0 @@ -6573,11 +6804,13 @@ object functions { /** * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` - * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. - * Returns `null`, in the case of an unparseable string. + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. Returns + * `null`, in the case of an unparseable string. * - * @param e a string column containing JSON data. - * @param schema the schema to use when parsing the json string + * @param e + * a string column containing JSON data. + * @param schema + * the schema to use when parsing the json string * * @group json_funcs * @since 2.4.0 @@ -6589,17 +6822,18 @@ object functions { // scalastyle:off line.size.limit /** * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` - * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. - * Returns `null`, in the case of an unparseable string. - * - * @param e a string column containing JSON data. - * @param schema the schema to use when parsing the json string - * @param options options to control how the json is parsed. accepts the same options and the - * json data source. - * See - * - * Data Source Option in the version you use. + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. Returns + * `null`, in the case of an unparseable string. + * + * @param e + * a string column containing JSON data. + * @param schema + * the schema to use when parsing the json string + * @param options + * options to control how the json is parsed. accepts the same options and the json data + * source. See Data + * Source Option in the version you use. * * @group json_funcs * @since 2.4.0 @@ -6620,7 +6854,8 @@ object functions { * Parses a JSON string and constructs a Variant value. Returns null if the input string is not * a valid JSON value. * - * @param json a string column that contains JSON data. + * @param json + * a string column that contains JSON data. 
* * @group variant_funcs * @since 4.0.0 @@ -6637,6 +6872,18 @@ object functions { */ def parse_json(json: Column): Column = Column.fn("parse_json", json) + /** + * Converts a column containing nested inputs (array/map/struct) into a variants where maps and + * structs are converted to variant objects which are unordered unlike SQL structs. Input maps + * can only have string keys. + * + * @param col + * a column with a nested schema or column name. + * @group variant_funcs + * @since 4.0.0 + */ + def to_variant_object(col: Column): Column = Column.fn("to_variant_object", col) + /** * Check if a variant value is a variant null. Returns true if and only if the input is a * variant null and false otherwise (including in the case of SQL NULL). @@ -6705,7 +6952,8 @@ object functions { /** * Parses a JSON string and infers its schema in DDL format. * - * @param json a JSON string. + * @param json + * a JSON string. * * @group json_funcs * @since 2.4.0 @@ -6715,7 +6963,8 @@ object functions { /** * Parses a JSON string and infers its schema in DDL format. * - * @param json a foldable string column containing a JSON string. + * @param json + * a foldable string column containing a JSON string. * * @group json_funcs * @since 2.4.0 @@ -6726,14 +6975,15 @@ object functions { /** * Parses a JSON string and infers its schema in DDL format using options. * - * @param json a foldable string column containing JSON data. - * @param options options to control how the json is parsed. accepts the same options and the - * json data source. - * See - * - * Data Source Option in the version you use. - * @return a column with string literal containing schema in DDL format. + * @param json + * a foldable string column containing JSON data. + * @param options + * options to control how the json is parsed. accepts the same options and the json data + * source. See Data + * Source Option in the version you use. + * @return + * a column with string literal containing schema in DDL format. * * @group json_funcs * @since 3.0.0 @@ -6743,8 +6993,8 @@ object functions { Column.fnWithOptions("schema_of_json", options.asScala.iterator, json) /** - * Returns the number of elements in the outermost JSON array. `NULL` is returned in case of - * any other valid JSON string, `NULL` or an invalid JSON. + * Returns the number of elements in the outermost JSON array. `NULL` is returned in case of any + * other valid JSON string, `NULL` or an invalid JSON. * * @group json_funcs * @since 3.5.0 @@ -6753,8 +7003,8 @@ object functions { /** * Returns all the keys of the outermost JSON object as an array. If a valid JSON object is - * given, all the keys of the outermost object will be returned as an array. If it is any - * other valid JSON string, an invalid JSON string or an empty string, the function returns null. + * given, all the keys of the outermost object will be returned as an array. If it is any other + * valid JSON string, an invalid JSON string or an empty string, the function returns null. * * @group json_funcs * @since 3.5.0 @@ -6763,19 +7013,18 @@ object functions { // scalastyle:off line.size.limit /** - * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` or - * a `MapType` into a JSON string with the specified schema. - * Throws an exception, in the case of an unsupported type. + * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` or a `MapType` into + * a JSON string with the specified schema. Throws an exception, in the case of an unsupported + * type. 
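Since `to_variant_object` is the one genuinely new API in this hunk (alongside the reflowed `parse_json` doc), a tentative usage sketch may help reviewers; it assumes the Spark 4.0 variant functions behave as documented above:
{{{
// parse_json turns a JSON string into a VARIANT value; invalid JSON yields null.
val df = Seq("""{"a": 1, "b": [2, 3]}""").toDF("json")
df.select(parse_json($"json").as("v"))

// to_variant_object converts a nested struct/array/map column into a variant object.
val nested = Seq((1, Seq("x", "y"))).toDF("id", "tags")
  .select(struct($"id", $"tags").as("s"))
nested.select(to_variant_object($"s").as("v"))
}}}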
* - * @param e a column containing a struct, an array or a map. - * @param options options to control how the struct column is converted into a json string. - * accepts the same options and the json data source. - * See - * - * Data Source Option in the version you use. - * Additionally the function supports the `pretty` option which enables - * pretty JSON generation. + * @param e + * a column containing a struct, an array or a map. + * @param options + * options to control how the struct column is converted into a json string. accepts the same + * options and the json data source. See Data + * Source Option in the version you use. Additionally the function supports the `pretty` + * option which enables pretty JSON generation. * * @group json_funcs * @since 2.1.0 @@ -6786,19 +7035,18 @@ object functions { // scalastyle:off line.size.limit /** - * (Java-specific) Converts a column containing a `StructType`, `ArrayType` or - * a `MapType` into a JSON string with the specified schema. - * Throws an exception, in the case of an unsupported type. + * (Java-specific) Converts a column containing a `StructType`, `ArrayType` or a `MapType` into + * a JSON string with the specified schema. Throws an exception, in the case of an unsupported + * type. * - * @param e a column containing a struct, an array or a map. - * @param options options to control how the struct column is converted into a json string. - * accepts the same options and the json data source. - * See - * - * Data Source Option in the version you use. - * Additionally the function supports the `pretty` option which enables - * pretty JSON generation. + * @param e + * a column containing a struct, an array or a map. + * @param options + * options to control how the struct column is converted into a json string. accepts the same + * options and the json data source. See Data + * Source Option in the version you use. Additionally the function supports the `pretty` + * option which enables pretty JSON generation. * * @group json_funcs * @since 2.1.0 @@ -6808,11 +7056,11 @@ object functions { to_json(e, options.asScala.toMap) /** - * Converts a column containing a `StructType`, `ArrayType` or - * a `MapType` into a JSON string with the specified schema. - * Throws an exception, in the case of an unsupported type. + * Converts a column containing a `StructType`, `ArrayType` or a `MapType` into a JSON string + * with the specified schema. Throws an exception, in the case of an unsupported type. * - * @param e a column containing a struct, an array or a map. + * @param e + * a column containing a struct, an array or a map. * * @group json_funcs * @since 2.1.0 @@ -6822,10 +7070,11 @@ object functions { /** * Masks the given string value. The function replaces characters with 'X' or 'x', and numbers - * with 'n'. - * This can be useful for creating copies of tables with sensitive information removed. + * with 'n'. This can be useful for creating copies of tables with sensitive information + * removed. * - * @param input string value to mask. Supported types: STRING, VARCHAR, CHAR + * @param input + * string value to mask. Supported types: STRING, VARCHAR, CHAR * * @group string_funcs * @since 3.5.0 @@ -6834,8 +7083,8 @@ object functions { /** * Masks the given string value. The function replaces upper-case characters with specific - * character, lower-case characters with 'x', and numbers with 'n'. - * This can be useful for creating copies of tables with sensitive information removed. 
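For the `to_json` variants above, a small round-trip sketch; the `pretty` option is the one called out in the doc text:
{{{
val people = Seq((1, "alice")).toDF("id", "name")
  .select(struct($"id", $"name").as("s"))

people.select(to_json($"s").as("json"))                          // {"id":1,"name":"alice"}
people.select(to_json($"s", Map("pretty" -> "true")).as("json")) // indented, multi-line output
}}}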
+ * character, lower-case characters with 'x', and numbers with 'n'. This can be useful for + * creating copies of tables with sensitive information removed. * * @param input * string value to mask. Supported types: STRING, VARCHAR, CHAR @@ -6850,8 +7099,8 @@ object functions { /** * Masks the given string value. The function replaces upper-case and lower-case characters with - * the characters specified respectively, and numbers with 'n'. - * This can be useful for creating copies of tables with sensitive information removed. + * the characters specified respectively, and numbers with 'n'. This can be useful for creating + * copies of tables with sensitive information removed. * * @param input * string value to mask. Supported types: STRING, VARCHAR, CHAR @@ -6868,8 +7117,8 @@ object functions { /** * Masks the given string value. The function replaces upper-case, lower-case characters and - * numbers with the characters specified respectively. - * This can be useful for creating copies of tables with sensitive information removed. + * numbers with the characters specified respectively. This can be useful for creating copies of + * tables with sensitive information removed. * * @param input * string value to mask. Supported types: STRING, VARCHAR, CHAR @@ -6916,8 +7165,8 @@ object functions { * Returns length of array or map. * * This function returns -1 for null input only if spark.sql.ansi.enabled is false and - * spark.sql.legacy.sizeOfNull is true. Otherwise, it returns null for null input. - * With the default settings, the function returns null for null input. + * spark.sql.legacy.sizeOfNull is true. Otherwise, it returns null for null input. With the + * default settings, the function returns null for null input. * * @group collection_funcs * @since 1.5.0 @@ -6928,8 +7177,8 @@ object functions { * Returns length of array or map. This is an alias of `size` function. * * This function returns -1 for null input only if spark.sql.ansi.enabled is false and - * spark.sql.legacy.sizeOfNull is true. Otherwise, it returns null for null input. - * With the default settings, the function returns null for null input. + * spark.sql.legacy.sizeOfNull is true. Otherwise, it returns null for null input. With the + * default settings, the function returns null for null input. * * @group collection_funcs * @since 3.5.0 @@ -6937,9 +7186,9 @@ object functions { def cardinality(e: Column): Column = Column.fn("cardinality", e) /** - * Sorts the input array for the given column in ascending order, - * according to the natural ordering of the array elements. - * Null elements will be placed at the beginning of the returned array. + * Sorts the input array for the given column in ascending order, according to the natural + * ordering of the array elements. Null elements will be placed at the beginning of the returned + * array. * * @group array_funcs * @since 1.5.0 @@ -6947,11 +7196,10 @@ object functions { def sort_array(e: Column): Column = sort_array(e, asc = true) /** - * Sorts the input array for the given column in ascending or descending order, - * according to the natural ordering of the array elements. NaN is greater than any non-NaN - * elements for double/float type. Null elements will be placed at the beginning of the returned - * array in ascending order or - * at the end of the returned array in descending order. + * Sorts the input array for the given column in ascending or descending order, according to the + * natural ordering of the array elements. 
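The `mask` overloads and `sort_array` documented in this stretch are easier to compare with concrete values; a hypothetical sketch:
{{{
val df = Seq(("AbCd-1234", Seq(3, 1, 2))).toDF("s", "xs")

df.select(
  mask($"s").as("masked"),                  // "XxXx-nnnn"
  mask($"s", lit("*")).as("customUpper"),   // "*x*x-nnnn"
  sort_array($"xs", asc = false).as("desc") // [3, 2, 1]
)
}}}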
NaN is greater than any non-NaN elements for + * double/float type. Null elements will be placed at the beginning of the returned array in + * ascending order or at the end of the returned array in descending order. * * @group array_funcs * @since 1.5.0 @@ -6987,8 +7235,9 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * - * @note The function is non-deterministic because the order of collected results depends - * on the order of the rows which may be non-deterministic after a shuffle. + * @note + * The function is non-deterministic because the order of collected results depends on the + * order of the rows which may be non-deterministic after a shuffle. * @group agg_funcs * @since 3.5.0 */ @@ -6997,7 +7246,8 @@ object functions { /** * Returns a random permutation of the given array. * - * @note The function is non-deterministic. + * @note + * The function is non-deterministic. * * @group array_funcs * @since 2.4.0 @@ -7012,8 +7262,8 @@ object functions { def reverse(e: Column): Column = Column.fn("reverse", e) /** - * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than - * two levels, only one level of nesting is removed. + * Creates a single array from an array of arrays. If a structure of nested arrays is deeper + * than two levels, only one level of nesting is removed. * @group array_funcs * @since 2.4.0 */ @@ -7029,8 +7279,8 @@ object functions { Column.fn("sequence", start, stop, step) /** - * Generate a sequence of integers from start to stop, - * incrementing by 1 if start is less than or equal to stop, otherwise -1. + * Generate a sequence of integers from start to stop, incrementing by 1 if start is less than + * or equal to stop, otherwise -1. * * @group array_funcs * @since 2.4.0 @@ -7038,8 +7288,8 @@ object functions { def sequence(start: Column, stop: Column): Column = Column.fn("sequence", start, stop) /** - * Creates an array containing the left argument repeated the number of times given by the - * right argument. + * Creates an array containing the left argument repeated the number of times given by the right + * argument. * * @group array_funcs * @since 2.4.0 @@ -7047,8 +7297,8 @@ object functions { def array_repeat(left: Column, right: Column): Column = Column.fn("array_repeat", left, right) /** - * Creates an array containing the left argument repeated the number of times given by the - * right argument. + * Creates an array containing the left argument repeated the number of times given by the right + * argument. * * @group array_funcs * @since 2.4.0 @@ -7113,14 +7363,15 @@ object functions { * Parses a column containing a CSV string into a `StructType` with the specified schema. * Returns `null`, in the case of an unparseable string. * - * @param e a string column containing CSV data. - * @param schema the schema to use when parsing the CSV string - * @param options options to control how the CSV is parsed. accepts the same options and the - * CSV data source. - * See - * - * Data Source Option in the version you use. + * @param e + * a string column containing CSV data. + * @param schema + * the schema to use when parsing the CSV string + * @param options + * options to control how the CSV is parsed. accepts the same options and the CSV data source. + * See Data + * Source Option in the version you use. 
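A brief sketch for the array constructors reflowed above (`sequence`, `array_repeat`, `flatten`), using literal inputs so it needs no source table beyond `spark.range`:
{{{
spark.range(1).select(
  sequence(lit(1), lit(10), lit(3)).as("seq"),                    // [1, 4, 7, 10]
  array_repeat(lit("ab"), 3).as("repeated"),                      // ["ab", "ab", "ab"]
  flatten(array(array(lit(1), lit(2)), array(lit(3)))).as("flat") // [1, 2, 3]
)
}}}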
* * @group csv_funcs * @since 3.0.0 @@ -7131,17 +7382,18 @@ object functions { // scalastyle:off line.size.limit /** - * (Java-specific) Parses a column containing a CSV string into a `StructType` - * with the specified schema. Returns `null`, in the case of an unparseable string. + * (Java-specific) Parses a column containing a CSV string into a `StructType` with the + * specified schema. Returns `null`, in the case of an unparseable string. * - * @param e a string column containing CSV data. - * @param schema the schema to use when parsing the CSV string - * @param options options to control how the CSV is parsed. accepts the same options and the - * CSV data source. - * See - * - * Data Source Option in the version you use. + * @param e + * a string column containing CSV data. + * @param schema + * the schema to use when parsing the CSV string + * @param options + * options to control how the CSV is parsed. accepts the same options and the CSV data source. + * See Data + * Source Option in the version you use. * * @group csv_funcs * @since 3.0.0 @@ -7156,7 +7408,8 @@ object functions { /** * Parses a CSV string and infers its schema in DDL format. * - * @param csv a CSV string. + * @param csv + * a CSV string. * * @group csv_funcs * @since 3.0.0 @@ -7166,7 +7419,8 @@ object functions { /** * Parses a CSV string and infers its schema in DDL format. * - * @param csv a foldable string column containing a CSV string. + * @param csv + * a foldable string column containing a CSV string. * * @group csv_funcs * @since 3.0.0 @@ -7177,14 +7431,15 @@ object functions { /** * Parses a CSV string and infers its schema in DDL format using options. * - * @param csv a foldable string column containing a CSV string. - * @param options options to control how the CSV is parsed. accepts the same options and the - * CSV data source. - * See - * - * Data Source Option in the version you use. - * @return a column with string literal containing schema in DDL format. + * @param csv + * a foldable string column containing a CSV string. + * @param options + * options to control how the CSV is parsed. accepts the same options and the CSV data source. + * See Data + * Source Option in the version you use. + * @return + * a column with string literal containing schema in DDL format. * * @group csv_funcs * @since 3.0.0 @@ -7195,16 +7450,16 @@ object functions { // scalastyle:off line.size.limit /** - * (Java-specific) Converts a column containing a `StructType` into a CSV string with - * the specified schema. Throws an exception, in the case of an unsupported type. + * (Java-specific) Converts a column containing a `StructType` into a CSV string with the + * specified schema. Throws an exception, in the case of an unsupported type. * - * @param e a column containing a struct. - * @param options options to control how the struct column is converted into a CSV string. - * It accepts the same options and the CSV data source. - * See - * - * Data Source Option in the version you use. + * @param e + * a column containing a struct. + * @param options + * options to control how the struct column is converted into a CSV string. It accepts the + * same options and the CSV data source. See Data + * Source Option in the version you use. * * @group csv_funcs * @since 3.0.0 @@ -7217,7 +7472,8 @@ object functions { * Converts a column containing a `StructType` into a CSV string with the specified schema. * Throws an exception, in the case of an unsupported type. * - * @param e a column containing a struct. 
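For the CSV counterparts (`from_csv`, `schema_of_csv`) documented in this hunk, a tentative example with an invented record; `sep` is a standard CSV data source option:
{{{
import org.apache.spark.sql.types._

val df = Seq("1,alice").toDF("csv")
val schema = new StructType().add("id", IntegerType).add("name", StringType)

df.select(from_csv($"csv", schema, Map("sep" -> ",")).as("row"))
  .select($"row.id", $"row.name")

// Infer a DDL schema string from a sample value.
df.select(schema_of_csv("1,alice").as("ddl"))
}}}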
+ * @param e + * a column containing a struct. * * @group csv_funcs * @since 3.0.0 @@ -7226,17 +7482,18 @@ object functions { // scalastyle:off line.size.limit /** - * Parses a column containing a XML string into the data type corresponding to the specified schema. - * Returns `null`, in the case of an unparseable string. - * - * @param e a string column containing XML data. - * @param schema the schema to use when parsing the XML string - * @param options options to control how the XML is parsed. accepts the same options and the - * XML data source. - * See - * - * Data Source Option in the version you use. + * Parses a column containing a XML string into the data type corresponding to the specified + * schema. Returns `null`, in the case of an unparseable string. + * + * @param e + * a string column containing XML data. + * @param schema + * the schema to use when parsing the XML string + * @param options + * options to control how the XML is parsed. accepts the same options and the XML data source. + * See Data + * Source Option in the version you use. * @group xml_funcs * @since 4.0.0 */ @@ -7246,18 +7503,18 @@ object functions { // scalastyle:off line.size.limit /** - * (Java-specific) Parses a column containing a XML string into a `StructType` - * with the specified schema. - * Returns `null`, in the case of an unparseable string. + * (Java-specific) Parses a column containing a XML string into a `StructType` with the + * specified schema. Returns `null`, in the case of an unparseable string. * - * @param e a string column containing XML data. - * @param schema the schema as a DDL-formatted string. - * @param options options to control how the XML is parsed. accepts the same options and the - * xml data source. - * See - * - * Data Source Option in the version you use. + * @param e + * a string column containing XML data. + * @param schema + * the schema as a DDL-formatted string. + * @param options + * options to control how the XML is parsed. accepts the same options and the xml data source. + * See Data + * Source Option in the version you use. * @group xml_funcs * @since 4.0.0 */ @@ -7268,11 +7525,13 @@ object functions { // scalastyle:off line.size.limit /** - * (Java-specific) Parses a column containing a XML string into a `StructType` - * with the specified schema. Returns `null`, in the case of an unparseable string. + * (Java-specific) Parses a column containing a XML string into a `StructType` with the + * specified schema. Returns `null`, in the case of an unparseable string. * - * @param e a string column containing XML data. - * @param schema the schema to use when parsing the XML string + * @param e + * a string column containing XML data. + * @param schema + * the schema to use when parsing the XML string * @group xml_funcs * @since 4.0.0 */ @@ -7283,17 +7542,18 @@ object functions { // scalastyle:off line.size.limit /** - * (Java-specific) Parses a column containing a XML string into a `StructType` - * with the specified schema. Returns `null`, in the case of an unparseable string. - * - * @param e a string column containing XML data. - * @param schema the schema to use when parsing the XML string - * @param options options to control how the XML is parsed. accepts the same options and the - * XML data source. - * See - * - * Data Source Option in the version you use. + * (Java-specific) Parses a column containing a XML string into a `StructType` with the + * specified schema. Returns `null`, in the case of an unparseable string. 
+ * + * @param e + * a string column containing XML data. + * @param schema + * the schema to use when parsing the XML string + * @param options + * options to control how the XML is parsed. accepts the same options and the XML data source. + * See Data + * Source Option in the version you use. * @group xml_funcs * @since 4.0.0 */ @@ -7302,13 +7562,14 @@ object functions { from_xml(e, schema, options.asScala.iterator) /** - * Parses a column containing a XML string into the data type - * corresponding to the specified schema. - * Returns `null`, in the case of an unparseable string. + * Parses a column containing a XML string into the data type corresponding to the specified + * schema. Returns `null`, in the case of an unparseable string. + * + * @param e + * a string column containing XML data. + * @param schema + * the schema to use when parsing the XML string * - * @param e a string column containing XML data. - * @param schema the schema to use when parsing the XML string - * @group xml_funcs * @since 4.0.0 */ @@ -7322,7 +7583,8 @@ object functions { /** * Parses a XML string and infers its schema in DDL format. * - * @param xml a XML string. + * @param xml + * a XML string. * @group xml_funcs * @since 4.0.0 */ @@ -7331,7 +7593,8 @@ object functions { /** * Parses a XML string and infers its schema in DDL format. * - * @param xml a foldable string column containing a XML string. + * @param xml + * a foldable string column containing a XML string. * @group xml_funcs * @since 4.0.0 */ @@ -7342,14 +7605,15 @@ object functions { /** * Parses a XML string and infers its schema in DDL format using options. * - * @param xml a foldable string column containing XML data. - * @param options options to control how the xml is parsed. accepts the same options and the - * XML data source. - * See - * - * Data Source Option in the version you use. - * @return a column with string literal containing schema in DDL format. + * @param xml + * a foldable string column containing XML data. + * @param options + * options to control how the xml is parsed. accepts the same options and the XML data source. + * See Data + * Source Option in the version you use. + * @return + * a column with string literal containing schema in DDL format. * @group xml_funcs * @since 4.0.0 */ @@ -7360,16 +7624,16 @@ object functions { // scalastyle:off line.size.limit /** - * (Java-specific) Converts a column containing a `StructType` into a XML string with - * the specified schema. Throws an exception, in the case of an unsupported type. + * (Java-specific) Converts a column containing a `StructType` into a XML string with the + * specified schema. Throws an exception, in the case of an unsupported type. * - * @param e a column containing a struct. - * @param options options to control how the struct column is converted into a XML string. - * It accepts the same options as the XML data source. - * See - * - * Data Source Option in the version you use. + * @param e + * a column containing a struct. + * @param options + * options to control how the struct column is converted into a XML string. It accepts the + * same options as the XML data source. See Data + * Source Option in the version you use. * @group xml_funcs * @since 4.0.0 */ @@ -7381,7 +7645,8 @@ object functions { * Converts a column containing a `StructType` into a XML string with the specified schema. * Throws an exception, in the case of an unsupported type. * - * @param e a column containing a struct. + * @param e + * a column containing a struct. 
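A hedged sketch of the XML helpers documented above (`schema_of_xml`, `from_xml`, `to_xml`). It assumes an active SparkSession `spark`, a made-up XML payload, and the two-argument `from_xml(column, StructType)` overload; treat it as illustrative rather than an exhaustive tour of the overloads.

{{{
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

// Assumed: an active SparkSession `spark`; the XML payload is made up.
val xml = "<person><name>Alice</name><age>42</age></person>"
val df = spark.sql(s"SELECT '$xml' AS value")

// Infer a DDL schema from a sample literal ...
df.select(schema_of_xml(lit(xml))).show(false)

// ... or parse with an explicit schema; unparseable strings yield null.
val schema = new StructType().add("name", StringType).add("age", LongType)
val parsed = df.select(from_xml(col("value"), schema).as("person"))
parsed.select(col("person.name"), col("person.age")).show()

// Convert the struct back into an XML string.
parsed.select(to_xml(col("person"))).show(false)
}}}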
* @group xml_funcs * @since 4.0.0 */ @@ -7430,8 +7695,8 @@ object functions { Column.fn("xpath_boolean", xml, path) /** - * Returns a double value, the value zero if no match is found, - * or NaN if a match is found but the value is non-numeric. + * Returns a double value, the value zero if no match is found, or NaN if a match is found but + * the value is non-numeric. * * @group xml_funcs * @since 3.5.0 @@ -7440,8 +7705,8 @@ object functions { Column.fn("xpath_double", xml, path) /** - * Returns a double value, the value zero if no match is found, - * or NaN if a match is found but the value is non-numeric. + * Returns a double value, the value zero if no match is found, or NaN if a match is found but + * the value is non-numeric. * * @group xml_funcs * @since 3.5.0 @@ -7450,8 +7715,8 @@ object functions { Column.fn("xpath_number", xml, path) /** - * Returns a float value, the value zero if no match is found, - * or NaN if a match is found but the value is non-numeric. + * Returns a float value, the value zero if no match is found, or NaN if a match is found but + * the value is non-numeric. * * @group xml_funcs * @since 3.5.0 @@ -7460,8 +7725,8 @@ object functions { Column.fn("xpath_float", xml, path) /** - * Returns an integer value, or the value zero if no match is found, - * or a match is found but the value is non-numeric. + * Returns an integer value, or the value zero if no match is found, or a match is found but the + * value is non-numeric. * * @group xml_funcs * @since 3.5.0 @@ -7470,8 +7735,8 @@ object functions { Column.fn("xpath_int", xml, path) /** - * Returns a long integer value, or the value zero if no match is found, - * or a match is found but the value is non-numeric. + * Returns a long integer value, or the value zero if no match is found, or a match is found but + * the value is non-numeric. * * @group xml_funcs * @since 3.5.0 @@ -7480,8 +7745,8 @@ object functions { Column.fn("xpath_long", xml, path) /** - * Returns a short integer value, or the value zero if no match is found, - * or a match is found but the value is non-numeric. + * Returns a short integer value, or the value zero if no match is found, or a match is found + * but the value is non-numeric. * * @group xml_funcs * @since 3.5.0 @@ -7507,13 +7772,16 @@ object functions { def hours(e: Column): Column = partitioning.hours(e) /** - * Converts the timestamp without time zone `sourceTs` - * from the `sourceTz` time zone to `targetTz`. + * Converts the timestamp without time zone `sourceTs` from the `sourceTz` time zone to + * `targetTz`. * - * @param sourceTz the time zone for the input timestamp. If it is missed, - * the current session time zone is used as the source time zone. - * @param targetTz the time zone to which the input timestamp should be converted. - * @param sourceTs a timestamp without time zone. + * @param sourceTz + * the time zone for the input timestamp. If it is missed, the current session time zone is + * used as the source time zone. + * @param targetTz + * the time zone to which the input timestamp should be converted. + * @param sourceTs + * a timestamp without time zone. * @group datetime_funcs * @since 3.5.0 */ @@ -7521,11 +7789,12 @@ object functions { Column.fn("convert_timezone", sourceTz, targetTz, sourceTs) /** - * Converts the timestamp without time zone `sourceTs` - * from the current time zone to `targetTz`. + * Converts the timestamp without time zone `sourceTs` from the current time zone to `targetTz`. 
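To make the numeric `xpath_*` contracts above concrete, a short sketch (assuming an active SparkSession `spark`; the sample XML is illustrative) showing the zero-on-no-match behavior described in the Scaladoc:

{{{
import org.apache.spark.sql.functions._

// Assumed: an active SparkSession `spark`.
val df = spark.sql("SELECT '<a><b>1</b><b>2</b></a>' AS xml")

df.select(
  xpath_int(col("xml"), lit("sum(a/b)")),      // 3
  xpath_double(col("xml"), lit("count(a/b)")), // 2.0
  xpath_long(col("xml"), lit("a/c"))           // no match -> 0
).show()
}}}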
* - * @param targetTz the time zone to which the input timestamp should be converted. - * @param sourceTs a timestamp without time zone. + * @param targetTz + * the time zone to which the input timestamp should be converted. + * @param sourceTs + * a timestamp without time zone. * @group datetime_funcs * @since 3.5.0 */ @@ -7818,8 +8087,8 @@ object functions { def isnotnull(col: Column): Column = Column.fn("isnotnull", col) /** - * Returns same result as the EQUAL(=) operator for non-null operands, - * but returns true if both are null, false if one of the them is null. + * Returns same result as the EQUAL(=) operator for non-null operands, but returns true if both + * are null, false if one of the them is null. * * @group predicate_funcs * @since 3.5.0 @@ -7908,15 +8177,15 @@ object functions { |}""".stripMargin) } - */ + */ ////////////////////////////////////////////////////////////////////////////////////////////// // Scala UDF functions ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Obtains a `UserDefinedFunction` that wraps the given `Aggregator` - * so that it may be used with untyped Data Frames. + * Obtains a `UserDefinedFunction` that wraps the given `Aggregator` so that it may be used with + * untyped Data Frames. * {{{ * val agg = // Aggregator[IN, BUF, OUT] * @@ -7928,24 +8197,30 @@ object functions { * spark.udf.register("myAggName", udaf(agg)) * }}} * - * @tparam IN the aggregator input type - * @tparam BUF the aggregating buffer type - * @tparam OUT the finalized output type + * @tparam IN + * the aggregator input type + * @tparam BUF + * the aggregating buffer type + * @tparam OUT + * the finalized output type * - * @param agg the typed Aggregator + * @param agg + * the typed Aggregator * - * @return a UserDefinedFunction that can be used as an aggregating expression. + * @return + * a UserDefinedFunction that can be used as an aggregating expression. * * @group udf_funcs - * @note The input encoder is inferred from the input type IN. + * @note + * The input encoder is inferred from the input type IN. */ def udaf[IN: TypeTag, BUF, OUT](agg: Aggregator[IN, BUF, OUT]): UserDefinedFunction = { udaf(agg, ScalaReflection.encoderFor[IN]) } /** - * Obtains a `UserDefinedFunction` that wraps the given `Aggregator` - * so that it may be used with untyped Data Frames. + * Obtains a `UserDefinedFunction` that wraps the given `Aggregator` so that it may be used with + * untyped Data Frames. * {{{ * Aggregator agg = // custom Aggregator * Encoder enc = // input encoder @@ -7958,18 +8233,24 @@ object functions { * spark.udf.register("myAggName", udaf(agg, enc)) * }}} * - * @tparam IN the aggregator input type - * @tparam BUF the aggregating buffer type - * @tparam OUT the finalized output type + * @tparam IN + * the aggregator input type + * @tparam BUF + * the aggregating buffer type + * @tparam OUT + * the finalized output type * - * @param agg the typed Aggregator - * @param inputEncoder a specific input encoder to use + * @param agg + * the typed Aggregator + * @param inputEncoder + * a specific input encoder to use * - * @return a UserDefinedFunction that can be used as an aggregating expression + * @return + * a UserDefinedFunction that can be used as an aggregating expression * * @group udf_funcs - * @note This overloading takes an explicit input encoder, to support UDAF - * declarations in Java. + * @note + * This overloading takes an explicit input encoder, to support UDAF declarations in Java. 
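A brief sketch of `convert_timezone` and the null-safe equality helper (`equal_null`) whose contracts are described above. The session, literals, and time zone names are assumptions for illustration.

{{{
import org.apache.spark.sql.functions._

// Assumed: an active SparkSession `spark`.
val df = spark.sql(
  "SELECT timestamp'2024-01-01 12:00:00' AS ts, CAST(NULL AS INT) AS a, 1 AS b")

df.select(
  // Reinterpret the wall-clock timestamp from UTC into another zone.
  convert_timezone(lit("UTC"), lit("America/Los_Angeles"), col("ts")).as("la_time"),
  // Two-argument form: the source zone defaults to the current session time zone.
  convert_timezone(lit("America/Los_Angeles"), col("ts")).as("la_time2"),
  // Null-safe equality: NULL vs NULL is true, NULL vs non-NULL is false.
  equal_null(col("a"), col("a")).as("both_null"),
  equal_null(col("a"), col("b")).as("one_null")
).show(false)
}}}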
*/ def udaf[IN, BUF, OUT]( agg: Aggregator[IN, BUF, OUT], @@ -7978,10 +8259,10 @@ object functions { } /** - * Defines a Scala closure of 0 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the Scala closure's - * signature. By default the returned UDF is deterministic. To change it to - * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 0 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -7991,10 +8272,10 @@ object functions { } /** - * Defines a Scala closure of 1 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the Scala closure's - * signature. By default the returned UDF is deterministic. To change it to - * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 1 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -8004,120 +8285,242 @@ object functions { } /** - * Defines a Scala closure of 2 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the Scala closure's - * signature. By default the returned UDF is deterministic. To change it to - * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 2 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ - def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { - SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]]) + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag]( + f: Function2[A1, A2, RT]): UserDefinedFunction = { + SparkUserDefinedFunction( + f, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]]) } /** - * Defines a Scala closure of 3 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the Scala closure's - * signature. By default the returned UDF is deterministic. To change it to - * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 3 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
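A minimal sketch of the Scala-closure `udf` variants documented above, including `asNondeterministic()`; it assumes an active SparkSession `spark`, and the closures themselves are made up for illustration.

{{{
import org.apache.spark.sql.functions._

// Assumed: an active SparkSession `spark`.
// Argument and return types are inferred from the closure signature.
val plus = udf((x: Int, y: Int) => x + y)

// Deterministic by default; opt out explicitly when the result must not be
// collapsed or re-evaluated freely by the optimizer.
val noise = udf(() => math.random()).asNondeterministic()

spark.range(3)
  .select(plus(col("id").cast("int"), lit(10)).as("shifted"), noise().as("noise"))
  .show()
}}}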
* * @group udf_funcs * @since 1.3.0 */ - def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]]) + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag]( + f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { + SparkUserDefinedFunction( + f, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]]) } /** - * Defines a Scala closure of 4 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the Scala closure's - * signature. By default the returned UDF is deterministic. To change it to - * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 4 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ - def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]]) + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag]( + f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { + SparkUserDefinedFunction( + f, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]]) } /** - * Defines a Scala closure of 5 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the Scala closure's - * signature. By default the returned UDF is deterministic. To change it to - * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 5 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ - def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]]) + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag]( + f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { + SparkUserDefinedFunction( + f, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]]) } /** - * Defines a Scala closure of 6 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the Scala closure's - * signature. By default the returned UDF is deterministic. To change it to - * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 6 arguments as user-defined function (UDF). 
The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ - def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]]) + def udf[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { + SparkUserDefinedFunction( + f, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]]) } /** - * Defines a Scala closure of 7 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the Scala closure's - * signature. By default the returned UDF is deterministic. To change it to - * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 7 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ - def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]]) + def udf[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { + SparkUserDefinedFunction( + f, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]]) } /** - * Defines a Scala closure of 8 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the Scala closure's - * signature. By default the returned UDF is deterministic. To change it to - * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 8 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
* * @group udf_funcs * @since 1.3.0 */ - def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { - SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]]) + def udf[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { + SparkUserDefinedFunction( + f, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]]) } /** - * Defines a Scala closure of 9 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the Scala closure's - * signature. By default the returned UDF is deterministic. To change it to - * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 9 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 */ - def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]]) + def udf[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { + SparkUserDefinedFunction( + f, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]]) } /** - * Defines a Scala closure of 10 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the Scala closure's - * signature. By default the returned UDF is deterministic. To change it to - * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 10 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
* * @group udf_funcs * @since 1.3.0 */ - def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]]) + def udf[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag]( + f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { + SparkUserDefinedFunction( + f, + implicitly[TypeTag[RT]], + implicitly[TypeTag[A1]], + implicitly[TypeTag[A2]], + implicitly[TypeTag[A3]], + implicitly[TypeTag[A4]], + implicitly[TypeTag[A5]], + implicitly[TypeTag[A6]], + implicitly[TypeTag[A7]], + implicitly[TypeTag[A8]], + implicitly[TypeTag[A9]], + implicitly[TypeTag[A10]]) } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -8125,10 +8528,10 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Defines a Java UDF0 instance as user-defined function (UDF). - * The caller must specify the output data type, and there is no automatic input type coercion. - * By default the returned UDF is deterministic. To change it to nondeterministic, call the - * API `UserDefinedFunction.asNondeterministic()`. + * Defines a Java UDF0 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 2.3.0 @@ -8138,10 +8541,10 @@ object functions { } /** - * Defines a Java UDF1 instance as user-defined function (UDF). - * The caller must specify the output data type, and there is no automatic input type coercion. - * By default the returned UDF is deterministic. To change it to nondeterministic, call the - * API `UserDefinedFunction.asNondeterministic()`. + * Defines a Java UDF1 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 2.3.0 @@ -8151,10 +8554,10 @@ object functions { } /** - * Defines a Java UDF2 instance as user-defined function (UDF). - * The caller must specify the output data type, and there is no automatic input type coercion. - * By default the returned UDF is deterministic. To change it to nondeterministic, call the - * API `UserDefinedFunction.asNondeterministic()`. + * Defines a Java UDF2 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
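A hedged sketch of the Java `UDF*` variants documented above, written in Scala against the `UDF2` interface; the session and the lambda body are assumptions for illustration.

{{{
import org.apache.spark.sql.api.java.UDF2
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.IntegerType

// Assumed: an active SparkSession `spark`.
// The Java-style variants take an explicit return DataType and perform no
// automatic input coercion, so the input column types must already line up.
val add: UDF2[Integer, Integer, Integer] =
  (a: Integer, b: Integer) => Integer.valueOf(a.intValue + b.intValue)

spark.range(3)
  .select(udf(add, IntegerType)(col("id").cast("int"), lit(1)).as("plus_one"))
  .show()
}}}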
* * @group udf_funcs * @since 2.3.0 @@ -8164,10 +8567,10 @@ object functions { } /** - * Defines a Java UDF3 instance as user-defined function (UDF). - * The caller must specify the output data type, and there is no automatic input type coercion. - * By default the returned UDF is deterministic. To change it to nondeterministic, call the - * API `UserDefinedFunction.asNondeterministic()`. + * Defines a Java UDF3 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 2.3.0 @@ -8177,10 +8580,10 @@ object functions { } /** - * Defines a Java UDF4 instance as user-defined function (UDF). - * The caller must specify the output data type, and there is no automatic input type coercion. - * By default the returned UDF is deterministic. To change it to nondeterministic, call the - * API `UserDefinedFunction.asNondeterministic()`. + * Defines a Java UDF4 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 2.3.0 @@ -8190,10 +8593,10 @@ object functions { } /** - * Defines a Java UDF5 instance as user-defined function (UDF). - * The caller must specify the output data type, and there is no automatic input type coercion. - * By default the returned UDF is deterministic. To change it to nondeterministic, call the - * API `UserDefinedFunction.asNondeterministic()`. + * Defines a Java UDF5 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 2.3.0 @@ -8203,10 +8606,10 @@ object functions { } /** - * Defines a Java UDF6 instance as user-defined function (UDF). - * The caller must specify the output data type, and there is no automatic input type coercion. - * By default the returned UDF is deterministic. To change it to nondeterministic, call the - * API `UserDefinedFunction.asNondeterministic()`. + * Defines a Java UDF6 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 2.3.0 @@ -8216,10 +8619,10 @@ object functions { } /** - * Defines a Java UDF7 instance as user-defined function (UDF). - * The caller must specify the output data type, and there is no automatic input type coercion. - * By default the returned UDF is deterministic. To change it to nondeterministic, call the - * API `UserDefinedFunction.asNondeterministic()`. + * Defines a Java UDF7 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
* * @group udf_funcs * @since 2.3.0 @@ -8229,10 +8632,10 @@ object functions { } /** - * Defines a Java UDF8 instance as user-defined function (UDF). - * The caller must specify the output data type, and there is no automatic input type coercion. - * By default the returned UDF is deterministic. To change it to nondeterministic, call the - * API `UserDefinedFunction.asNondeterministic()`. + * Defines a Java UDF8 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 2.3.0 @@ -8242,10 +8645,10 @@ object functions { } /** - * Defines a Java UDF9 instance as user-defined function (UDF). - * The caller must specify the output data type, and there is no automatic input type coercion. - * By default the returned UDF is deterministic. To change it to nondeterministic, call the - * API `UserDefinedFunction.asNondeterministic()`. + * Defines a Java UDF9 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 2.3.0 @@ -8255,15 +8658,17 @@ object functions { } /** - * Defines a Java UDF10 instance as user-defined function (UDF). - * The caller must specify the output data type, and there is no automatic input type coercion. - * By default the returned UDF is deterministic. To change it to nondeterministic, call the - * API `UserDefinedFunction.asNondeterministic()`. + * Defines a Java UDF10 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 2.3.0 */ - def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + def udf( + f: UDF10[_, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): UserDefinedFunction = { SparkUserDefinedFunction(ToScalaUDF(f), returnType, 10) } @@ -8273,8 +8678,8 @@ object functions { /** * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant, * the caller must specify the output data type, and there is no automatic input type coercion. - * By default the returned UDF is deterministic. To change it to nondeterministic, call the - * API `UserDefinedFunction.asNondeterministic()`. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. * * Note that, although the Scala closure can have primitive-type function argument, it doesn't * work well with null values. Because the Scala closure is passed in as Any type, there is no @@ -8283,14 +8688,18 @@ object functions { * default value of the Java type for the null argument, e.g. `udf((x: Int) => x, IntegerType)`, * the result is 0 for null input. 
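To illustrate the null-handling caveat spelled out above, a short sketch contrasting the typed closure variant with the deprecated untyped one; `spark` and the sample query are assumptions.

{{{
import org.apache.spark.sql.functions.{col, udf}

// Assumed: an active SparkSession `spark`.
// With the typed variant Spark knows the closure takes a primitive Int, so a null
// input produces a null result rather than the primitive default value 0.
val increment = udf((x: Int) => x + 1)

spark.sql("SELECT CAST(NULL AS INT) AS x UNION ALL SELECT 41")
  .select(col("x"), increment(col("x")).as("x_plus_one"))
  .show()

// The deprecated untyped form, e.g. udf((x: Int) => x + 1, IntegerType), cannot see the
// closure's types: the same null input surfaces as 0, and the call is rejected outright
// unless the legacy `legacyAllowUntypedScalaUDFs` setting referenced above is enabled.
}}}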
* - * @param f A closure in Scala - * @param dataType The output data type of the UDF + * @param f + * A closure in Scala + * @param dataType + * The output data type of the UDF * * @group udf_funcs * @since 2.0.0 */ - @deprecated("Scala `udf` method with return type parameter is deprecated. " + - "Please use Scala `udf` method without return type parameter.", "3.0.0") + @deprecated( + "Scala `udf` method with return type parameter is deprecated. " + + "Please use Scala `udf` method without return type parameter.", + "3.0.0") def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { if (!SqlApiConf.get.legacyAllowUntypedScalaUDFs) { throw CompilationErrors.usingUntypedScalaUDFError() @@ -8309,8 +8718,7 @@ object functions { def callUDF(udfName: String, cols: Column*): Column = call_function(udfName, cols: _*) /** - * Call an user-defined function. - * Example: + * Call an user-defined function. Example: * {{{ * import org.apache.spark.sql._ * @@ -8329,9 +8737,10 @@ object functions { /** * Call a SQL function. * - * @param funcName function name that follows the SQL identifier syntax - * (can be quoted, can be qualified) - * @param cols the expression parameters of function + * @param funcName + * function name that follows the SQL identifier syntax (can be quoted, can be qualified) + * @param cols + * the expression parameters of function * @group normal_funcs * @since 3.5.0 */ @@ -8352,7 +8761,7 @@ object functions { // API in the same way. Once we land this fix, should deprecate // functions.hours, days, months, years and bucket. object partitioning { - // scalastyle:on + // scalastyle:on /** * (Scala-specific) A transform for timestamps and dates to partition data into years. * diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala index 5762f9f6f5668..555a567053080 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala @@ -25,8 +25,8 @@ import org.apache.spark.util.SparkClassUtils /** * Configuration for all objects that are placed in the `sql/api` project. The normal way of - * accessing this class is through `SqlApiConf.get`. If this code is being used with sql/core - * then its values are bound to the currently set SQLConf. With Spark Connect, it will default to + * accessing this class is through `SqlApiConf.get`. If this code is being used with sql/core then + * its values are bound to the currently set SQLConf. With Spark Connect, it will default to * hardcoded values. */ private[sql] trait SqlApiConf { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala index b7b8e14afb387..13ef13e5894e0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala @@ -21,9 +21,9 @@ import java.util.concurrent.atomic.AtomicReference /** * SqlApiConfHelper is created to avoid a deadlock during a concurrent access to SQLConf and * SqlApiConf, which is because SQLConf and SqlApiConf tries to load each other upon - * initializations. SqlApiConfHelper is private to sql package and is not supposed to be - * accessed by end users. Variables and methods within SqlApiConfHelper are defined to - * be used by SQLConf and SqlApiConf only. + * initializations. 
SqlApiConfHelper is private to sql package and is not supposed to be accessed + * by end users. Variables and methods within SqlApiConfHelper are defined to be used by SQLConf + * and SqlApiConf only. */ private[sql] object SqlApiConfHelper { // Shared keys. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala index 25ea37fadfa91..4d476108d9ec5 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala @@ -16,14 +16,56 @@ */ package org.apache.spark.sql.internal +import scala.jdk.CollectionConverters._ + +import org.apache.spark.api.java.function.{CoGroupFunction, FilterFunction, FlatMapFunction, FlatMapGroupsFunction, FlatMapGroupsWithStateFunction, ForeachFunction, ForeachPartitionFunction, MapFunction, MapGroupsFunction, MapGroupsWithStateFunction, MapPartitionsFunction, ReduceFunction} import org.apache.spark.sql.api.java._ +import org.apache.spark.sql.streaming.GroupState /** - * Helper class that provides conversions from org.apache.spark.sql.api.java.Function* to - * scala.Function*. + * Helper class that provides conversions from org.apache.spark.sql.api.java.Function* and + * org.apache.spark.api.java.function.* to scala functions. + * + * Please note that this class is being used in Spark Connect Scala UDFs. We need to be careful + * with any modifications to this class, otherwise we will break backwards compatibility. + * Concretely this means you can only add methods to this class. You cannot rename the class, move + * it, change its `serialVersionUID`, remove methods, change method signatures, or change method + * semantics. */ @SerialVersionUID(2019907615267866045L) private[sql] object ToScalaUDF extends Serializable { + def apply[T](f: FilterFunction[T]): T => Boolean = f.call + + def apply[T](f: ReduceFunction[T]): (T, T) => T = f.call + + def apply[V, W](f: MapFunction[V, W]): V => W = f.call + + def apply[K, V, U](f: MapGroupsFunction[K, V, U]): (K, Iterator[V]) => U = + (key, values) => f.call(key, values.asJava) + + def apply[K, V, S, U]( + f: MapGroupsWithStateFunction[K, V, S, U]): (K, Iterator[V], GroupState[S]) => U = + (key, values, state) => f.call(key, values.asJava, state) + + def apply[V, U](f: MapPartitionsFunction[V, U]): Iterator[V] => Iterator[U] = + values => f.call(values.asJava).asScala + + def apply[K, V, U](f: FlatMapGroupsFunction[K, V, U]): (K, Iterator[V]) => Iterator[U] = + (key, values) => f.call(key, values.asJava).asScala + + def apply[K, V, S, U](f: FlatMapGroupsWithStateFunction[K, V, S, U]) + : (K, Iterator[V], GroupState[S]) => Iterator[U] = + (key, values, state) => f.call(key, values.asJava, state).asScala + + def apply[K, V, U, R]( + f: CoGroupFunction[K, V, U, R]): (K, Iterator[V], Iterator[U]) => Iterator[R] = + (key, left, right) => f.call(key, left.asJava, right.asJava).asScala + + def apply[V](f: ForeachFunction[V]): V => Unit = f.call + + def apply[V](f: ForeachPartitionFunction[V]): Iterator[V] => Unit = + values => f.call(values.asJava) + // scalastyle:off line.size.limit /* register 0-22 were generated by this script @@ -38,171 +80,757 @@ private[sql] object ToScalaUDF extends Serializable { |/** | * Create a scala.Function$i wrapper for a org.apache.spark.sql.api.java.UDF$i instance. 
| */ - |def apply(f: UDF$i[$extTypeArgs]): AnyRef = { + |def apply(f: UDF$i[$extTypeArgs]): Function$i[$anyTypeArgs] = { | $funcCall |}""".stripMargin) } - */ + */ /** * Create a scala.Function0 wrapper for a org.apache.spark.sql.api.java.UDF0 instance. */ - def apply(f: UDF0[_]): AnyRef = { - () => f.asInstanceOf[UDF0[Any]].call() + def apply(f: UDF0[_]): () => Any = { () => + f.asInstanceOf[UDF0[Any]].call() } /** * Create a scala.Function1 wrapper for a org.apache.spark.sql.api.java.UDF1 instance. */ - def apply(f: UDF1[_, _]): AnyRef = { + def apply(f: UDF1[_, _]): (Any) => Any = { f.asInstanceOf[UDF1[Any, Any]].call(_: Any) } /** * Create a scala.Function2 wrapper for a org.apache.spark.sql.api.java.UDF2 instance. */ - def apply(f: UDF2[_, _, _]): AnyRef = { + def apply(f: UDF2[_, _, _]): (Any, Any) => Any = { f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) } /** * Create a scala.Function3 wrapper for a org.apache.spark.sql.api.java.UDF3 instance. */ - def apply(f: UDF3[_, _, _, _]): AnyRef = { + def apply(f: UDF3[_, _, _, _]): (Any, Any, Any) => Any = { f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) } /** * Create a scala.Function4 wrapper for a org.apache.spark.sql.api.java.UDF4 instance. */ - def apply(f: UDF4[_, _, _, _, _]): AnyRef = { + def apply(f: UDF4[_, _, _, _, _]): (Any, Any, Any, Any) => Any = { f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) } /** * Create a scala.Function5 wrapper for a org.apache.spark.sql.api.java.UDF5 instance. */ - def apply(f: UDF5[_, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF5[_, _, _, _, _, _]): (Any, Any, Any, Any, Any) => Any = { + f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]] + .call(_: Any, _: Any, _: Any, _: Any, _: Any) } /** * Create a scala.Function6 wrapper for a org.apache.spark.sql.api.java.UDF6 instance. */ - def apply(f: UDF6[_, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF6[_, _, _, _, _, _, _]): (Any, Any, Any, Any, Any, Any) => Any = { + f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]] + .call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) } /** * Create a scala.Function7 wrapper for a org.apache.spark.sql.api.java.UDF7 instance. */ - def apply(f: UDF7[_, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF7[_, _, _, _, _, _, _, _]): (Any, Any, Any, Any, Any, Any, Any) => Any = { + f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]] + .call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) } /** * Create a scala.Function8 wrapper for a org.apache.spark.sql.api.java.UDF8 instance. */ - def apply(f: UDF8[_, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply( + f: UDF8[_, _, _, _, _, _, _, _, _]): (Any, Any, Any, Any, Any, Any, Any, Any) => Any = { + f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]] + .call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) } /** * Create a scala.Function9 wrapper for a org.apache.spark.sql.api.java.UDF9 instance. 
*/ - def apply(f: UDF9[_, _, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF9[_, _, _, _, _, _, _, _, _, _]) + : (Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any = { + f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]] + .call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) } /** * Create a scala.Function10 wrapper for a org.apache.spark.sql.api.java.UDF10 instance. */ - def apply(f: UDF10[_, _, _, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF10[_, _, _, _, _, _, _, _, _, _, _]) + : Function10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = { + f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]] + .call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) } /** * Create a scala.Function11 wrapper for a org.apache.spark.sql.api.java.UDF11 instance. */ - def apply(f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _]) + : Function11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = { + f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]] + .call( + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any) } /** * Create a scala.Function12 wrapper for a org.apache.spark.sql.api.java.UDF12 instance. */ - def apply(f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]) + : Function12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = { + f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]] + .call( + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any) } /** * Create a scala.Function13 wrapper for a org.apache.spark.sql.api.java.UDF13 instance. */ - def apply(f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]) + : Function13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = { + f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]] + .call( + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any) } /** * Create a scala.Function14 wrapper for a org.apache.spark.sql.api.java.UDF14 instance. 
*/ - def apply(f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]) + : Function14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = { + f.asInstanceOf[ + UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]] + .call( + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any) } /** * Create a scala.Function15 wrapper for a org.apache.spark.sql.api.java.UDF15 instance. */ - def apply(f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function15[ + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any] = { + f.asInstanceOf[ + UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]] + .call( + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any) } /** * Create a scala.Function16 wrapper for a org.apache.spark.sql.api.java.UDF16 instance. */ - def apply(f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function16[ + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any] = { + f.asInstanceOf[ + UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]] + .call( + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any) } /** * Create a scala.Function17 wrapper for a org.apache.spark.sql.api.java.UDF17 instance. 
*/ - def apply(f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function17[ + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any] = { + f.asInstanceOf[UDF17[ + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any]] + .call( + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any) } /** * Create a scala.Function18 wrapper for a org.apache.spark.sql.api.java.UDF18 instance. */ - def apply(f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function18[ + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any] = { + f.asInstanceOf[UDF18[ + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any]] + .call( + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any) } /** * Create a scala.Function19 wrapper for a org.apache.spark.sql.api.java.UDF19 instance. */ - def apply(f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function19[ + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any] = { + f.asInstanceOf[UDF19[ + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any]] + .call( + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any) } /** * Create a scala.Function20 wrapper for a org.apache.spark.sql.api.java.UDF20 instance. 
*/ - def apply(f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply(f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function20[ + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any] = { + f.asInstanceOf[UDF20[ + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any]] + .call( + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any) } /** * Create a scala.Function21 wrapper for a org.apache.spark.sql.api.java.UDF21 instance. */ - def apply(f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply( + f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function21[ + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any] = { + f.asInstanceOf[UDF21[ + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any]] + .call( + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any) } /** * Create a scala.Function22 wrapper for a org.apache.spark.sql.api.java.UDF22 instance. */ - def apply(f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = { - f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + def apply( + f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function22[ + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any] = { + f.asInstanceOf[UDF22[ + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any]] + .call( + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any, + _: Any) } // scalastyle:on line.size.limit } + +/** + * Adaptors from one UDF shape to another. 
For example adapting a foreach function for use in + * foreachPartition. + * + * Please note that this class is being used in Spark Connect Scala UDFs. We need to be careful + * with any modifications to this class, otherwise we will break backwards compatibility. + * Concretely this means you can only add methods to this class. You cannot rename the class, move + * it, change its `serialVersionUID`, remove methods, change method signatures, or change method + * semantics. + */ +@SerialVersionUID(0L) // TODO +object UDFAdaptors extends Serializable { + def flatMapToMapPartitions[V, U](f: V => IterableOnce[U]): Iterator[V] => Iterator[U] = + values => values.flatMap(f) + + def flatMapToMapPartitions[V, U](f: FlatMapFunction[V, U]): Iterator[V] => Iterator[U] = + values => values.flatMap(v => f.call(v).asScala) + + def mapToMapPartitions[V, U](f: V => U): Iterator[V] => Iterator[U] = values => values.map(f) + + def mapToMapPartitions[V, U](f: MapFunction[V, U]): Iterator[V] => Iterator[U] = + values => values.map(f.call) + + def foreachToForeachPartition[T](f: T => Unit): Iterator[T] => Unit = + values => values.foreach(f) + + def foreachToForeachPartition[T](f: ForeachFunction[T]): Iterator[T] => Unit = + values => values.foreach(f.call) + + def foreachPartitionToMapPartitions[V, U](f: Iterator[V] => Unit): Iterator[V] => Iterator[U] = + values => { + f(values) + Iterator.empty[U] + } + + def iterableOnceToSeq[A, B](f: A => IterableOnce[B]): A => Seq[B] = + value => f(value).iterator.toSeq + + def mapGroupsToFlatMapGroups[K, V, U]( + f: (K, Iterator[V]) => U): (K, Iterator[V]) => Iterator[U] = + (key, values) => Iterator.single(f(key, values)) + + def mapGroupsWithStateToFlatMapWithState[K, V, S, U]( + f: (K, Iterator[V], GroupState[S]) => U): (K, Iterator[V], GroupState[S]) => Iterator[U] = + (key: K, values: Iterator[V], state: GroupState[S]) => Iterator(f(key, values, state)) + + def coGroupWithMappedValues[K, V, U, R, IV, IU]( + f: (K, Iterator[V], Iterator[U]) => IterableOnce[R], + leftValueMapFunc: Option[IV => V], + rightValueMapFunc: Option[IU => U]): (K, Iterator[IV], Iterator[IU]) => IterableOnce[R] = { + (leftValueMapFunc, rightValueMapFunc) match { + case (None, None) => + f.asInstanceOf[(K, Iterator[IV], Iterator[IU]) => IterableOnce[R]] + case (Some(mapLeft), None) => + (k, left, right) => f(k, left.map(mapLeft), right.asInstanceOf[Iterator[U]]) + case (None, Some(mapRight)) => + (k, left, right) => f(k, left.asInstanceOf[Iterator[V]], right.map(mapRight)) + case (Some(mapLeft), Some(mapRight)) => + (k, left, right) => f(k, left.map(mapLeft), right.map(mapRight)) + } + } + + def flatMapGroupsWithMappedValues[K, IV, V, R]( + f: (K, Iterator[V]) => IterableOnce[R], + valueMapFunc: Option[IV => V]): (K, Iterator[IV]) => IterableOnce[R] = valueMapFunc match { + case Some(mapValue) => (k, values) => f(k, values.map(mapValue)) + case None => f.asInstanceOf[(K, Iterator[IV]) => IterableOnce[R]] + } + + def flatMapGroupsWithStateWithMappedValues[K, IV, V, S, U]( + f: (K, Iterator[V], GroupState[S]) => Iterator[U], + valueMapFunc: Option[IV => V]): (K, Iterator[IV], GroupState[S]) => Iterator[U] = { + valueMapFunc match { + case Some(mapValue) => (k, values, state) => f(k, values.map(mapValue), state) + case None => f.asInstanceOf[(K, Iterator[IV], GroupState[S]) => Iterator[U]] + } + } +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala index 9d77b2e6f3e22..51b26a1fa2435 100644 
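Aside (commentary, not part of the patch): the new `UDFAdaptors` helpers above all follow the same shape, lifting a per-element function into an iterator-to-iterator function so it can be applied once per partition. A minimal standalone sketch of that pattern, with made-up sample functions and data:

```scala
object AdaptorSketch {
  // Mirrors UDFAdaptors.mapToMapPartitions: lift A => B to Iterator[A] => Iterator[B].
  def mapToMapPartitions[A, B](f: A => B): Iterator[A] => Iterator[B] =
    values => values.map(f)

  // Mirrors UDFAdaptors.flatMapToMapPartitions: lift A => IterableOnce[B].
  def flatMapToMapPartitions[A, B](f: A => IterableOnce[B]): Iterator[A] => Iterator[B] =
    values => values.flatMap(f)

  def main(args: Array[String]): Unit = {
    // Illustrative inputs only; in Spark the partition iterators would be supplied by the runtime.
    val wordLengths = mapToMapPartitions((s: String) => s.length)
    val splitWords = flatMapToMapPartitions((s: String) => s.split(" ").toSeq)
    println(wordLengths(Iterator("spark", "sql")).toList) // List(5, 3)
    println(splitWords(Iterator("hello world")).toList)   // List(hello, world)
  }
}
```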
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala @@ -32,10 +32,12 @@ import org.apache.spark.util.SparkClassUtils * implementation specific form (e.g. Catalyst expressions, or Connect protobuf messages). * * This API is a mirror image of Connect's expression.proto. There are a couple of extensions to - * make constructing nodes easier (e.g. [[CaseWhenOtherwise]]). We could not use the actual connect - * protobuf messages because of classpath clashes (e.g. Guava & gRPC) and Maven shading issues. + * make constructing nodes easier (e.g. [[CaseWhenOtherwise]]). We could not use the actual + * connect protobuf messages because of classpath clashes (e.g. Guava & gRPC) and Maven shading + * issues. */ private[sql] trait ColumnNode extends ColumnNodeLike { + /** * Origin where the node was created. */ @@ -93,13 +95,17 @@ private[internal] object ColumnNode { /** * A literal column. * - * @param value of the literal. This is the unconverted input value. - * @param dataType of the literal. If none is provided the dataType is inferred. + * @param value + * of the literal. This is the unconverted input value. + * @param dataType + * of the literal. If none is provided the dataType is inferred. */ private[sql] case class Literal( value: Any, dataType: Option[DataType] = None, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode with DataTypeErrorsBase { + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode + with DataTypeErrorsBase { override private[internal] def normalize(): Literal = copy(origin = NO_ORIGIN) override def sql: String = value match { @@ -116,16 +122,19 @@ private[sql] case class Literal( /** * Reference to an attribute produced by one of the underlying DataFrames. * - * @param unparsedIdentifier name of the attribute. - * @param planId id of the plan (Dataframe) that produces the attribute. - * @param isMetadataColumn whether this is a metadata column. + * @param unparsedIdentifier + * name of the attribute. + * @param planId + * id of the plan (Dataframe) that produces the attribute. + * @param isMetadataColumn + * whether this is a metadata column. */ private[sql] case class UnresolvedAttribute( unparsedIdentifier: String, planId: Option[Long] = None, isMetadataColumn: Boolean = false, override val origin: Origin = CurrentOrigin.get) - extends ColumnNode { + extends ColumnNode { override private[internal] def normalize(): UnresolvedAttribute = copy(planId = None, origin = NO_ORIGIN) override def sql: String = unparsedIdentifier @@ -134,14 +143,16 @@ private[sql] case class UnresolvedAttribute( /** * Reference to all columns in a namespace (global, a Dataframe, or a nested struct). * - * @param unparsedTarget name of the namespace. None if the global namespace is supposed to be used. - * @param planId id of the plan (Dataframe) that produces the attribute. + * @param unparsedTarget + * name of the namespace. None if the global namespace is supposed to be used. + * @param planId + * id of the plan (Dataframe) that produces the attribute. 
*/ private[sql] case class UnresolvedStar( unparsedTarget: Option[String], planId: Option[Long] = None, override val origin: Origin = CurrentOrigin.get) - extends ColumnNode { + extends ColumnNode { override private[internal] def normalize(): UnresolvedStar = copy(planId = None, origin = NO_ORIGIN) override def sql: String = unparsedTarget.map(_ + ".*").getOrElse("*") @@ -151,10 +162,12 @@ private[sql] case class UnresolvedStar( * Call a function. This can either be a built-in function, a UDF, or a UDF registered in the * Catalog. * - * @param functionName of the function to invoke. - * @param arguments to pass into the function. - * @param isDistinct (aggregate only) whether the input of the aggregate function should be - * de-duplicated. + * @param functionName + * of the function to invoke. + * @param arguments + * to pass into the function. + * @param isDistinct + * (aggregate only) whether the input of the aggregate function should be de-duplicated. */ private[sql] case class UnresolvedFunction( functionName: String, @@ -163,7 +176,7 @@ private[sql] case class UnresolvedFunction( isUserDefinedFunction: Boolean = false, isInternal: Boolean = false, override val origin: Origin = CurrentOrigin.get) - extends ColumnNode { + extends ColumnNode { override private[internal] def normalize(): UnresolvedFunction = copy(arguments = ColumnNode.normalize(arguments), origin = NO_ORIGIN) @@ -173,11 +186,13 @@ private[sql] case class UnresolvedFunction( /** * Evaluate a SQL expression. * - * @param expression text to execute. + * @param expression + * text to execute. */ private[sql] case class SqlExpression( expression: String, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode { override private[internal] def normalize(): SqlExpression = copy(origin = NO_ORIGIN) override def sql: String = expression } @@ -185,15 +200,19 @@ private[sql] case class SqlExpression( /** * Name a column, and (optionally) modify its metadata. * - * @param child to name - * @param name to use - * @param metadata (optional) metadata to add. + * @param child + * to name + * @param name + * to use + * @param metadata + * (optional) metadata to add. */ private[sql] case class Alias( child: ColumnNode, name: Seq[String], metadata: Option[Metadata] = None, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode { override private[internal] def normalize(): Alias = copy(child = child.normalize(), origin = NO_ORIGIN) @@ -210,15 +229,19 @@ private[sql] case class Alias( * Cast the value of a Column to a different [[DataType]]. The behavior of the cast can be * influenced by the `evalMode`. * - * @param child that produces the input value. - * @param dataType to cast to. - * @param evalMode (try/ansi/legacy) to use for the cast. + * @param child + * that produces the input value. + * @param dataType + * to cast to. + * @param evalMode + * (try/ansi/legacy) to use for the cast. */ private[sql] case class Cast( child: ColumnNode, dataType: DataType, evalMode: Option[Cast.EvalMode] = None, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode { override private[internal] def normalize(): Cast = copy(child = child.normalize(), origin = NO_ORIGIN) @@ -237,13 +260,16 @@ private[sql] object Cast { /** * Reference to all columns in the global namespace in that match a regex. 
* - * @param regex name of the namespace. None if the global namespace is supposed to be used. - * @param planId id of the plan (Dataframe) that produces the attribute. + * @param regex + * name of the namespace. None if the global namespace is supposed to be used. + * @param planId + * id of the plan (Dataframe) that produces the attribute. */ private[sql] case class UnresolvedRegex( regex: String, planId: Option[Long] = None, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode { override private[internal] def normalize(): UnresolvedRegex = copy(planId = None, origin = NO_ORIGIN) override def sql: String = regex @@ -252,16 +278,19 @@ private[sql] case class UnresolvedRegex( /** * Sort the input column. * - * @param child to sort. - * @param sortDirection to sort in, either Ascending or Descending. - * @param nullOrdering where to place nulls, either at the begin or the end. + * @param child + * to sort. + * @param sortDirection + * to sort in, either Ascending or Descending. + * @param nullOrdering + * where to place nulls, either at the beginning or the end. */ private[sql] case class SortOrder( child: ColumnNode, sortDirection: SortOrder.SortDirection, nullOrdering: SortOrder.NullOrdering, override val origin: Origin = CurrentOrigin.get) - extends ColumnNode { + extends ColumnNode { override private[internal] def normalize(): SortOrder = copy(child = child.normalize(), origin = NO_ORIGIN) @@ -280,14 +309,16 @@ private[sql] object SortOrder { /** * Evaluate a function within a window. * - * @param windowFunction function to execute. - * @param windowSpec of the window. + * @param windowFunction + * function to execute. + * @param windowSpec + * of the window. */ private[sql] case class Window( windowFunction: ColumnNode, windowSpec: WindowSpec, override val origin: Origin = CurrentOrigin.get) - extends ColumnNode { + extends ColumnNode { override private[internal] def normalize(): Window = copy( windowFunction = windowFunction.normalize(), windowSpec = windowSpec.normalize(), @@ -299,7 +330,8 @@ private[sql] case class Window( private[sql] case class WindowSpec( partitionColumns: Seq[ColumnNode], sortColumns: Seq[SortOrder], - frame: Option[WindowFrame] = None) extends ColumnNodeLike { + frame: Option[WindowFrame] = None) + extends ColumnNodeLike { override private[internal] def normalize(): WindowSpec = copy( partitionColumns = ColumnNode.normalize(partitionColumns), sortColumns = ColumnNode.normalize(sortColumns), @@ -317,7 +349,7 @@ private[sql] case class WindowFrame( frameType: WindowFrame.FrameType, lower: WindowFrame.FrameBoundary, upper: WindowFrame.FrameBoundary) - extends ColumnNodeLike { + extends ColumnNodeLike { override private[internal] def normalize(): WindowFrame = copy(lower = lower.normalize(), upper = upper.normalize()) override private[internal] def sql: String = @@ -352,13 +384,16 @@ private[sql] object WindowFrame { /** * Lambda function to execute. This typically passed as an argument to a function. * - * @param function to execute. - * @param arguments the bound lambda variables. + * @param function + * to execute. + * @param arguments + * the bound lambda variables.
*/ private[sql] case class LambdaFunction( function: ColumnNode, arguments: Seq[UnresolvedNamedLambdaVariable], - override val origin: Origin) extends ColumnNode { + override val origin: Origin) + extends ColumnNode { override private[internal] def normalize(): LambdaFunction = copy( function = function.normalize(), @@ -382,11 +417,13 @@ object LambdaFunction { /** * Variable used in a [[LambdaFunction]]. * - * @param name of the variable. + * @param name + * of the variable. */ private[sql] case class UnresolvedNamedLambdaVariable( name: String, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode { override private[internal] def normalize(): UnresolvedNamedLambdaVariable = copy(origin = NO_ORIGIN) @@ -413,21 +450,22 @@ object UnresolvedNamedLambdaVariable { } /** - * Extract a value from a complex type. This can be a field from a struct, a value from a map, - * or an element from an array. + * Extract a value from a complex type. This can be a field from a struct, a value from a map, or + * an element from an array. * - * @param child that produces a complex value. - * @param extraction that is used to access the complex type. This needs to be a string type for - * structs and maps, and it needs to be an integer for arrays. + * @param child + * that produces a complex value. + * @param extraction + * that is used to access the complex type. This needs to be a string type for structs and maps, + * and it needs to be an integer for arrays. */ private[sql] case class UnresolvedExtractValue( child: ColumnNode, extraction: ColumnNode, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode { - override private[internal] def normalize(): UnresolvedExtractValue = copy( - child = child.normalize(), - extraction = extraction.normalize(), - origin = NO_ORIGIN) + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode { + override private[internal] def normalize(): UnresolvedExtractValue = + copy(child = child.normalize(), extraction = extraction.normalize(), origin = NO_ORIGIN) override def sql: String = s"${child.sql}[${extraction.sql}]" } @@ -435,15 +473,19 @@ private[sql] case class UnresolvedExtractValue( /** * Update or drop the field of a struct. * - * @param structExpression that will be updated. - * @param fieldName name of the field to update. - * @param valueExpression new value of the field. If this is None the field will be dropped. + * @param structExpression + * that will be updated. + * @param fieldName + * name of the field to update. + * @param valueExpression + * new value of the field. If this is None the field will be dropped. */ private[sql] case class UpdateFields( structExpression: ColumnNode, fieldName: String, valueExpression: Option[ColumnNode] = None, - override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode { override private[internal] def normalize(): UpdateFields = copy( structExpression = structExpression.normalize(), valueExpression = ColumnNode.normalize(valueExpression), @@ -455,18 +497,20 @@ private[sql] case class UpdateFields( } /** - * Evaluate one or more conditional branches. The value of the first branch for which the predicate - * evalutes to true is returned. If none of the branches evaluate to true, the value of `otherwise` - * is returned. + * Evaluate one or more conditional branches. 
The value of the first branch for which the + * predicate evaluates to true is returned. If none of the branches evaluate to true, the value of + * `otherwise` is returned. * - * @param branches to evaluate. Each entry if a pair of condition and value. - * @param otherwise (optional) to evaluate when none of the branches evaluate to true. + * @param branches + * to evaluate. Each entry is a pair of condition and value. + * @param otherwise + * (optional) to evaluate when none of the branches evaluate to true. */ private[sql] case class CaseWhenOtherwise( branches: Seq[(ColumnNode, ColumnNode)], otherwise: Option[ColumnNode] = None, override val origin: Origin = CurrentOrigin.get) - extends ColumnNode { + extends ColumnNode { assert(branches.nonEmpty) override private[internal] def normalize(): CaseWhenOtherwise = copy( branches = branches.map(kv => (kv._1.normalize(), kv._2.normalize())), @@ -483,15 +527,17 @@ private[sql] case class CaseWhenOtherwise( /** * Invoke an inline user defined function. * - * @param function to invoke. - * @param arguments to pass into the user defined function. + * @param function + * to invoke. + * @param arguments + * to pass into the user defined function. */ private[sql] case class InvokeInlineUserDefinedFunction( function: UserDefinedFunctionLike, arguments: Seq[ColumnNode], isDistinct: Boolean = false, override val origin: Origin = CurrentOrigin.get) - extends ColumnNode { + extends ColumnNode { override private[internal] def normalize(): InvokeInlineUserDefinedFunction = copy(arguments = ColumnNode.normalize(arguments), origin = NO_ORIGIN) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala index 406449a337271..5c8c77985bb2c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.internal.types import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType} - /** * Use AbstractArrayType(AbstractDataType) for defining expected types for expression parameters. */ @@ -30,7 +29,7 @@ case class AbstractArrayType(elementType: AbstractDataType) extends AbstractData override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[ArrayType] && - elementType.acceptsType(other.asInstanceOf[ArrayType].elementType) + elementType.acceptsType(other.asInstanceOf[ArrayType].elementType) } override private[spark] def simpleString: String = s"array<${elementType.simpleString}>" diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala index 62f422f6f80a7..32f4341839f01 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala @@ -19,23 +19,20 @@ package org.apache.spark.sql.internal.types import org.apache.spark.sql.types.{AbstractDataType, DataType, MapType} - /** - * Use AbstractMapType(AbstractDataType, AbstractDataType) - * for defining expected types for expression parameters. + * Use AbstractMapType(AbstractDataType, AbstractDataType) for defining expected types for + * expression parameters.
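Aside (commentary, not part of the patch): the columnNodes.scala hunks above are a scalafmt reflow of the ColumnNode AST, but they also show its one structural convention: every node's `normalize()` strips per-call metadata such as `origin` (and `planId`) so that semantically identical trees compare equal. A standalone sketch of that convention with a made-up two-node AST, not Spark's actual classes:

```scala
object NormalizeSketch {
  // Stand-ins for Origin/NO_ORIGIN from the real code; purely illustrative.
  final case class Origin(line: Int)
  val NO_ORIGIN: Origin = Origin(-1)

  sealed trait Node { def normalize(): Node }
  final case class Attr(name: String, origin: Origin) extends Node {
    override def normalize(): Attr = copy(origin = NO_ORIGIN)
  }
  final case class Func(name: String, args: Seq[Node], origin: Origin) extends Node {
    override def normalize(): Func = copy(args = args.map(_.normalize()), origin = NO_ORIGIN)
  }

  def main(args: Array[String]): Unit = {
    // Built at different source locations, so not equal as-is, but equal once normalized.
    val a = Func("upper", Seq(Attr("name", Origin(10))), Origin(10))
    val b = Func("upper", Seq(Attr("name", Origin(42))), Origin(42))
    assert(a != b)
    assert(a.normalize() == b.normalize())
  }
}
```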
*/ -case class AbstractMapType( - keyType: AbstractDataType, - valueType: AbstractDataType - ) extends AbstractDataType { +case class AbstractMapType(keyType: AbstractDataType, valueType: AbstractDataType) + extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = MapType(keyType.defaultConcreteType, valueType.defaultConcreteType, valueContainsNull = true) override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[MapType] && - keyType.acceptsType(other.asInstanceOf[MapType].keyType) && - valueType.acceptsType(other.asInstanceOf[MapType].valueType) + keyType.acceptsType(other.asInstanceOf[MapType].keyType) && + valueType.acceptsType(other.asInstanceOf[MapType].valueType) } override private[spark] def simpleString: String = diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index 05d1701eff74d..dc4ee013fd189 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -51,3 +51,12 @@ case object StringTypeBinaryLcase extends AbstractStringType { case object StringTypeAnyCollation extends AbstractStringType { override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] } + +/** + * Use StringTypeNonCSAICollation for expressions supporting all possible collation types except + * CS_AI collation types. + */ +case object StringTypeNonCSAICollation extends AbstractStringType { + override private[sql] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala index 49dc393f8481b..a0958aceb3b3a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala @@ -22,12 +22,13 @@ import java.io.Serializable import org.apache.spark.annotation.{Evolving, Experimental} /** - * Class used to provide access to expired timer's expiry time. These values - * are only relevant if the ExpiredTimerInfo is valid. + * Class used to provide access to expired timer's expiry time. These values are only relevant if + * the ExpiredTimerInfo is valid. */ @Experimental @Evolving private[sql] trait ExpiredTimerInfo extends Serializable { + /** * Check if provided ExpiredTimerInfo is valid. */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index f08a2fd3cc55c..146990917a3fc 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -27,89 +27,89 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * `flatMapGroupsWithState` operations on `KeyValueGroupedDataset`. 
* * Detail description on `[map/flatMap]GroupsWithState` operation - * -------------------------------------------------------------- - * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in `KeyValueGroupedDataset` - * will invoke the user-given function on each group (defined by the grouping function in - * `Dataset.groupByKey()`) while maintaining a user-defined per-group state between invocations. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger. - * That is, in every batch of the `StreamingQuery`, - * the function will be invoked once for each group that has data in the trigger. Furthermore, - * if timeout is set, then the function will be invoked on timed-out groups (more detail below). + * -------------------------------------------------------------- Both, `mapGroupsWithState` and + * `flatMapGroupsWithState` in `KeyValueGroupedDataset` will invoke the user-given function on + * each group (defined by the grouping function in `Dataset.groupByKey()`) while maintaining a + * user-defined per-group state between invocations. For a static batch Dataset, the function will + * be invoked once per group. For a streaming Dataset, the function will be invoked for each group + * repeatedly in every trigger. That is, in every batch of the `StreamingQuery`, the function will + * be invoked once for each group that has data in the trigger. Furthermore, if timeout is set, + * then the function will be invoked on timed-out groups (more detail below). * * The function is invoked with the following parameters. - * - The key of the group. - * - An iterator containing all the values for this group. - * - A user-defined state object set by previous invocations of the given function. + * - The key of the group. + * - An iterator containing all the values for this group. + * - A user-defined state object set by previous invocations of the given function. * * In case of a batch Dataset, there is only one invocation and the state object will be empty as - * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` - * is equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have - * no effect. + * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` is + * equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have no + * effect. * * The major difference between `mapGroupsWithState` and `flatMapGroupsWithState` is that the - * former allows the function to return one and only one record, whereas the latter - * allows the function to return any number of records (including no records). Furthermore, the + * former allows the function to return one and only one record, whereas the latter allows the + * function to return any number of records (including no records). Furthermore, the * `flatMapGroupsWithState` is associated with an operation output mode, which can be either - * `Append` or `Update`. Semantically, this defines whether the output records of one trigger - * is effectively replacing the previously output records (from previous triggers) or is appending - * to the list of previously output records. Essentially, this defines how the Result Table (refer - * to the semantics in the programming guide) is updated, and allows us to reason about the - * semantics of later operations. + * `Append` or `Update`. 
Semantically, this defines whether the output records of one trigger is + * effectively replacing the previously output records (from previous triggers) or is appending to + * the list of previously output records. Essentially, this defines how the Result Table (refer to + * the semantics in the programming guide) is updated, and allows us to reason about the semantics + * of later operations. * - * Important points to note about the function (both mapGroupsWithState and flatMapGroupsWithState). - * - In a trigger, the function will be called only the groups present in the batch. So do not - * assume that the function will be called in every trigger for every group that has state. - * - There is no guaranteed ordering of values in the iterator in the function, neither with - * batch, nor with streaming Datasets. - * - All the data will be shuffled before applying the function. - * - If timeout is set, then the function will also be called with no values. - * See more details on `GroupStateTimeout` below. + * Important points to note about the function (both mapGroupsWithState and + * flatMapGroupsWithState). + * - In a trigger, the function will be called only the groups present in the batch. So do not + * assume that the function will be called in every trigger for every group that has state. + * - There is no guaranteed ordering of values in the iterator in the function, neither with + * batch, nor with streaming Datasets. + * - All the data will be shuffled before applying the function. + * - If timeout is set, then the function will also be called with no values. See more details + * on `GroupStateTimeout` below. * * Important points to note about using `GroupState`. - * - The value of the state cannot be null. So updating state with null will throw - * `IllegalArgumentException`. - * - Operations on `GroupState` are not thread-safe. This is to avoid memory barriers. - * - If `remove()` is called, then `exists()` will return `false`, - * `get()` will throw `NoSuchElementException` and `getOption()` will return `None` - * - After that, if `update(newState)` is called, then `exists()` will again return `true`, - * `get()` and `getOption()`will return the updated value. + * - The value of the state cannot be null. So updating state with null will throw + * `IllegalArgumentException`. + * - Operations on `GroupState` are not thread-safe. This is to avoid memory barriers. + * - If `remove()` is called, then `exists()` will return `false`, `get()` will throw + * `NoSuchElementException` and `getOption()` will return `None` + * - After that, if `update(newState)` is called, then `exists()` will again return `true`, + * `get()` and `getOption()`will return the updated value. * * Important points to note about using `GroupStateTimeout`. - * - The timeout type is a global param across all the groups (set as `timeout` param in - * `[map|flatMap]GroupsWithState`, but the exact timeout duration/timestamp is configurable per - * group by calling `setTimeout...()` in `GroupState`. - * - Timeouts can be either based on processing time (i.e. - * `GroupStateTimeout.ProcessingTimeTimeout`) or event time (i.e. - * `GroupStateTimeout.EventTimeTimeout`). - * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling - * `GroupState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the set - * duration. 
Guarantees provided by this timeout with a duration of D ms are as follows: - * - Timeout will never occur before the clock time has advanced by D ms - * - Timeout will occur eventually when there is a trigger in the query - * (i.e. after D ms). So there is no strict upper bound on when the timeout would occur. - * For example, the trigger interval of the query will affect when the timeout actually occurs. - * If there is no data in the stream (for any group) for a while, then there will not be - * any trigger and timeout function call will not occur until there is data. - * - Since the processing time timeout is based on the clock time, it is affected by the - * variations in the system clock (i.e. time zone changes, clock skew, etc.). - * - With `EventTimeTimeout`, the user also has to specify the event time watermark in - * the query using `Dataset.withWatermark()`. With this setting, data that is older than the - * watermark is filtered out. The timeout can be set for a group by setting a timeout timestamp - * using`GroupState.setTimeoutTimestamp()`, and the timeout would occur when the watermark - * advances beyond the set timestamp. You can control the timeout delay by two parameters - - * (i) watermark delay and an additional duration beyond the timestamp in the event (which - * is guaranteed to be newer than watermark due to the filtering). Guarantees provided by this - * timeout are as follows: - * - Timeout will never occur before the watermark has exceeded the set timeout. - * - Similar to processing time timeouts, there is no strict upper bound on the delay when - * the timeout actually occurs. The watermark can advance only when there is data in the - * stream and the event time of the data has actually advanced. - * - When the timeout occurs for a group, the function is called for that group with no values, and - * `GroupState.hasTimedOut()` set to true. - * - The timeout is reset every time the function is called on a group, that is, - * when the group has new data, or the group has timed out. So the user has to set the timeout - * duration every time the function is called, otherwise, there will not be any timeout set. + * - The timeout type is a global param across all the groups (set as `timeout` param in + * `[map|flatMap]GroupsWithState`, but the exact timeout duration/timestamp is configurable + * per group by calling `setTimeout...()` in `GroupState`. + * - Timeouts can be either based on processing time (i.e. + * `GroupStateTimeout.ProcessingTimeTimeout`) or event time (i.e. + * `GroupStateTimeout.EventTimeTimeout`). + * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling + * `GroupState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the + * set duration. Guarantees provided by this timeout with a duration of D ms are as follows: + * - Timeout will never occur before the clock time has advanced by D ms + * - Timeout will occur eventually when there is a trigger in the query (i.e. after D ms). So + * there is no strict upper bound on when the timeout would occur. For example, the trigger + * interval of the query will affect when the timeout actually occurs. If there is no data + * in the stream (for any group) for a while, then there will not be any trigger and timeout + * function call will not occur until there is data. + * - Since the processing time timeout is based on the clock time, it is affected by the + * variations in the system clock (i.e. time zone changes, clock skew, etc.). 
+ * - With `EventTimeTimeout`, the user also has to specify the event time watermark in the query + * using `Dataset.withWatermark()`. With this setting, data that is older than the watermark + * is filtered out. The timeout can be set for a group by setting a timeout timestamp + * using`GroupState.setTimeoutTimestamp()`, and the timeout would occur when the watermark + * advances beyond the set timestamp. You can control the timeout delay by two parameters - + * (i) watermark delay and an additional duration beyond the timestamp in the event (which is + * guaranteed to be newer than watermark due to the filtering). Guarantees provided by this + * timeout are as follows: + * - Timeout will never occur before the watermark has exceeded the set timeout. + * - Similar to processing time timeouts, there is no strict upper bound on the delay when the + * timeout actually occurs. The watermark can advance only when there is data in the stream + * and the event time of the data has actually advanced. + * - When the timeout occurs for a group, the function is called for that group with no values, + * and `GroupState.hasTimedOut()` set to true. + * - The timeout is reset every time the function is called on a group, that is, when the group + * has new data, or the group has timed out. So the user has to set the timeout duration every + * time the function is called, otherwise, there will not be any timeout set. * * `[map/flatMap]GroupsWithState` can take a user defined initial state as an additional argument. * This state will be applied when the first batch of the streaming query is processed. If there @@ -181,7 +181,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * state.setTimeoutDuration("1 hour"); // Set the timeout * } * ... -* // return something + * // return something * } * }; * @@ -191,8 +191,9 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * mappingFunction, Encoders.INT, Encoders.STRING, GroupStateTimeout.ProcessingTimeTimeout); * }}} * - * @tparam S User-defined type of the state to be stored for each group. Must be encodable into - * Spark SQL types (see `Encoder` for more details). + * @tparam S + * User-defined type of the state to be stored for each group. Must be encodable into Spark SQL + * types (see `Encoder` for more details). * @since 2.2.0 */ @Experimental @@ -217,44 +218,48 @@ trait GroupState[S] extends LogicalGroupState[S] { /** * Whether the function has been called because the key has timed out. - * @note This can return true only when timeouts are enabled in `[map/flatMap]GroupsWithState`. + * @note + * This can return true only when timeouts are enabled in `[map/flatMap]GroupsWithState`. */ def hasTimedOut: Boolean - /** * Set the timeout duration in ms for this key. * - * @note [[GroupStateTimeout Processing time timeout]] must be enabled in - * `[map/flatMap]GroupsWithState` for calling this method. - * @note This method has no effect when used in a batch query. + * @note + * [[GroupStateTimeout Processing time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note + * This method has no effect when used in a batch query. */ @throws[IllegalArgumentException]("if 'durationMs' is not positive") @throws[UnsupportedOperationException]( "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutDuration(durationMs: Long): Unit - /** * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc. 
* - * @note [[GroupStateTimeout Processing time timeout]] must be enabled in - * `[map/flatMap]GroupsWithState` for calling this method. - * @note This method has no effect when used in a batch query. + * @note + * [[GroupStateTimeout Processing time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note + * This method has no effect when used in a batch query. */ @throws[IllegalArgumentException]("if 'duration' is not a valid duration") @throws[UnsupportedOperationException]( "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutDuration(duration: String): Unit - /** - * Set the timeout timestamp for this key as milliseconds in epoch time. - * This timestamp cannot be older than the current watermark. + * Set the timeout timestamp for this key as milliseconds in epoch time. This timestamp cannot + * be older than the current watermark. * - * @note [[GroupStateTimeout Event time timeout]] must be enabled in - * `[map/flatMap]GroupsWithState` for calling this method. - * @note This method has no effect when used in a batch query. + * @note + * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState` + * for calling this method. + * @note + * This method has no effect when used in a batch query. */ @throws[IllegalArgumentException]( "if 'timestampMs' is not positive or less than the current watermark in a streaming query") @@ -262,16 +267,16 @@ trait GroupState[S] extends LogicalGroupState[S] { "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestampMs: Long): Unit - /** * Set the timeout timestamp for this key as milliseconds in epoch time and an additional - * duration as a string (e.g. "1 hour", "2 days", etc.). - * The final timestamp (including the additional duration) cannot be older than the - * current watermark. + * duration as a string (e.g. "1 hour", "2 days", etc.). The final timestamp (including the + * additional duration) cannot be older than the current watermark. * - * @note [[GroupStateTimeout Event time timeout]] must be enabled in - * `[map/flatMap]GroupsWithState` for calling this method. - * @note This method has no side effect when used in a batch query. + * @note + * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState` + * for calling this method. + * @note + * This method has no side effect when used in a batch query. */ @throws[IllegalArgumentException]( "if 'additionalDuration' is invalid or the final timeout timestamp is less than " + @@ -280,56 +285,57 @@ trait GroupState[S] extends LogicalGroupState[S] { "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit - /** - * Set the timeout timestamp for this key as a java.sql.Date. - * This timestamp cannot be older than the current watermark. + * Set the timeout timestamp for this key as a java.sql.Date. This timestamp cannot be older + * than the current watermark. * - * @note [[GroupStateTimeout Event time timeout]] must be enabled in - * `[map/flatMap]GroupsWithState` for calling this method. - * @note This method has no side effect when used in a batch query. + * @note + * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState` + * for calling this method. + * @note + * This method has no side effect when used in a batch query. 
*/ @throws[UnsupportedOperationException]( "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestamp: java.sql.Date): Unit - /** - * Set the timeout timestamp for this key as a java.sql.Date and an additional - * duration as a string (e.g. "1 hour", "2 days", etc.). - * The final timestamp (including the additional duration) cannot be older than the - * current watermark. + * Set the timeout timestamp for this key as a java.sql.Date and an additional duration as a + * string (e.g. "1 hour", "2 days", etc.). The final timestamp (including the additional + * duration) cannot be older than the current watermark. * - * @note [[GroupStateTimeout Event time timeout]] must be enabled in - * `[map/flatMap]GroupsWithState` for calling this method. - * @note This method has no side effect when used in a batch query. + * @note + * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState` + * for calling this method. + * @note + * This method has no side effect when used in a batch query. */ @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") @throws[UnsupportedOperationException]( "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit - /** * Get the current event time watermark as milliseconds in epoch time. * - * @note In a streaming query, this can be called only when watermark is set before calling - * `[map/flatMap]GroupsWithState`. In a batch query, this method always returns -1. - * @note The watermark gets propagated in the end of each query. As a result, this method will - * return 0 (1970-01-01T00:00:00) for the first micro-batch. If you use this value - * as a part of the timestamp set in the `setTimeoutTimestamp`, it may lead to the - * state expiring immediately in the next micro-batch, once the watermark gets the - * real value from your data. + * @note + * In a streaming query, this can be called only when watermark is set before calling + * `[map/flatMap]GroupsWithState`. In a batch query, this method always returns -1. + * @note + * The watermark gets propagated in the end of each query. As a result, this method will + * return 0 (1970-01-01T00:00:00) for the first micro-batch. If you use this value as a part + * of the timestamp set in the `setTimeoutTimestamp`, it may lead to the state expiring + * immediately in the next micro-batch, once the watermark gets the real value from your data. */ @throws[UnsupportedOperationException]( "if watermark has not been set before in [map|flatMap]GroupsWithState") def getCurrentWatermarkMs(): Long - /** * Get the current processing time as milliseconds in epoch time. - * @note In a streaming query, this will return a constant value throughout the duration of a - * trigger, even if the trigger is re-executed. + * @note + * In a streaming query, this will return a constant value throughout the duration of a + * trigger, even if the trigger is re-executed. 
*/ def getCurrentProcessingTimeMs(): Long } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala index 0e2d6cc3778c6..568578d1f4ff6 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala @@ -21,8 +21,7 @@ import org.apache.spark.annotation.{Evolving, Experimental} @Experimental @Evolving /** - * Interface used for arbitrary stateful operations with the v2 API to capture - * list value state. + * Interface used for arbitrary stateful operations with the v2 API to capture list value state. */ private[sql] trait ListState[S] extends Serializable { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/MapState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/MapState.scala index 030c3ee989c6f..7b01888bbac49 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/MapState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/MapState.scala @@ -21,10 +21,10 @@ import org.apache.spark.annotation.{Evolving, Experimental} @Experimental @Evolving /** - * Interface used for arbitrary stateful operations with the v2 API to capture - * map value state. + * Interface used for arbitrary stateful operations with the v2 API to capture map value state. */ trait MapState[K, V] extends Serializable { + /** Whether state exists or not. */ def exists(): Boolean @@ -35,7 +35,7 @@ trait MapState[K, V] extends Serializable { def containsKey(key: K): Boolean /** Update value for given user key */ - def updateValue(key: K, value: V) : Unit + def updateValue(key: K, value: V): Unit /** Get the map associated with grouping key */ def iterator(): Iterator[(K, V)] diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/QueryInfo.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/QueryInfo.scala index 7754a514fdd6b..f239bcff49fea 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/QueryInfo.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/QueryInfo.scala @@ -22,8 +22,8 @@ import java.util.UUID import org.apache.spark.annotation.{Evolving, Experimental} /** - * Represents the query info provided to the stateful processor used in the - * arbitrary state API v2 to easily identify task retries on the same partition. + * Represents the query info provided to the stateful processor used in the arbitrary state API v2 + * to easily identify task retries on the same partition. */ @Experimental @Evolving diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala index 54e6a9a4ab678..d2c6010454c55 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala @@ -31,31 +31,35 @@ import org.apache.spark.sql.errors.ExecutionErrors private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable { /** - * Handle to the stateful processor that provides access to the state store and other - * stateful processing related APIs. + * Handle to the stateful processor that provides access to the state store and other stateful + * processing related APIs. 
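Aside (commentary, not part of the patch): the GroupState scaladoc reflowed above is dense, so here is a hedged usage sketch of the documented contract, a running count per key with a processing-time timeout. Only members named in the docs above (`hasTimedOut`, `get`, `getOption`, `update`, `remove`, `setTimeoutDuration`) are used; the `mapGroupsWithState` wiring in the trailing comment is assumed and not shown in this diff.

```scala
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

object RunningCountSketch {
  // Invoked once per key and trigger; also invoked with no values when the key times out.
  def updateCount(key: String, values: Iterator[String], state: GroupState[Int]): (String, Int) = {
    if (state.hasTimedOut) {
      val expired = (key, state.get)
      state.remove() // afterwards exists() is false and getOption() is None, per the docs above
      expired
    } else {
      val newCount = state.getOption.getOrElse(0) + values.size
      state.update(newCount)             // updating with null would throw IllegalArgumentException
      state.setTimeoutDuration("1 hour") // must be set again on every invocation
      (key, newCount)
    }
  }

  // Assumed wiring, with `keyed: KeyValueGroupedDataset[String, String]`:
  //   keyed.mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout)(updateCount)
}
```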
*/ private var statefulProcessorHandle: StatefulProcessorHandle = null /** - * Function that will be invoked as the first method that allows for users to - * initialize all their state variables and perform other init actions before handling data. - * @param outputMode - output mode for the stateful processor - * @param timeMode - time mode for the stateful processor. + * Function that will be invoked as the first method that allows for users to initialize all + * their state variables and perform other init actions before handling data. + * @param outputMode + * \- output mode for the stateful processor + * @param timeMode + * \- time mode for the stateful processor. */ - def init( - outputMode: OutputMode, - timeMode: TimeMode): Unit + def init(outputMode: OutputMode, timeMode: TimeMode): Unit /** * Function that will allow users to interact with input data rows along with the grouping key * and current timer values and optionally provide output rows. - * @param key - grouping key - * @param inputRows - iterator of input rows associated with grouping key - * @param timerValues - instance of TimerValues that provides access to current processing/event - * time if available - * @param expiredTimerInfo - instance of ExpiredTimerInfo that provides access to expired timer - * if applicable - * @return - Zero or more output rows + * @param key + * \- grouping key + * @param inputRows + * \- iterator of input rows associated with grouping key + * @param timerValues + * \- instance of TimerValues that provides access to current processing/event time if + * available + * @param expiredTimerInfo + * \- instance of ExpiredTimerInfo that provides access to expired timer if applicable + * @return + * \- Zero or more output rows */ def handleInputRows( key: K, @@ -64,16 +68,17 @@ private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable { expiredTimerInfo: ExpiredTimerInfo): Iterator[O] /** - * Function called as the last method that allows for users to perform - * any cleanup or teardown operations. + * Function called as the last method that allows for users to perform any cleanup or teardown + * operations. */ - def close (): Unit = {} + def close(): Unit = {} /** * Function to set the stateful processor handle that will be used to interact with the state * store and other stateful processor related operations. * - * @param handle - instance of StatefulProcessorHandle + * @param handle + * \- instance of StatefulProcessorHandle */ final def setHandle(handle: StatefulProcessorHandle): Unit = { statefulProcessorHandle = handle @@ -82,7 +87,8 @@ private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable { /** * Function to get the stateful processor handle that will be used to interact with the state * - * @return handle - instance of StatefulProcessorHandle + * @return + * handle - instance of StatefulProcessorHandle */ final def getHandle: StatefulProcessorHandle = { if (statefulProcessorHandle == null) { @@ -93,23 +99,25 @@ private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable { } /** - * Stateful processor with support for specifying initial state. - * Accepts a user-defined type as initial state to be initialized in the first batch. - * This can be used for starting a new streaming query with existing state from a - * previous streaming query. + * Stateful processor with support for specifying initial state. Accepts a user-defined type as + * initial state to be initialized in the first batch. 
This can be used for starting a new + * streaming query with existing state from a previous streaming query. */ @Experimental @Evolving private[sql] abstract class StatefulProcessorWithInitialState[K, I, O, S] - extends StatefulProcessor[K, I, O] { + extends StatefulProcessor[K, I, O] { /** * Function that will be invoked only in the first batch for users to process initial states. * - * @param key - grouping key - * @param initialState - A row in the initial state to be processed - * @param timerValues - instance of TimerValues that provides access to current processing/event - * time if available + * @param key + * \- grouping key + * @param initialState + * \- A row in the initial state to be processed + * @param timerValues + * \- instance of TimerValues that provides access to current processing/event time if + * available */ def handleInitialState(key: K, initialState: S, timerValues: TimerValues): Unit } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala index 4dc2ca875ef0e..d1eca0f3967d9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala @@ -22,39 +22,46 @@ import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql.Encoder /** - * Represents the operation handle provided to the stateful processor used in the - * arbitrary state API v2. + * Represents the operation handle provided to the stateful processor used in the arbitrary state + * API v2. */ @Experimental @Evolving private[sql] trait StatefulProcessorHandle extends Serializable { /** - * Function to create new or return existing single value state variable of given type. - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. + * Function to create new or return existing single value state variable of given type. The user + * must ensure to call this function only within the `init()` method of the StatefulProcessor. * - * @param stateName - name of the state variable - * @param valEncoder - SQL encoder for state variable - * @tparam T - type of state variable - * @return - instance of ValueState of type T that can be used to store state persistently + * @param stateName + * \- name of the state variable + * @param valEncoder + * \- SQL encoder for state variable + * @tparam T + * \- type of state variable + * @return + * \- instance of ValueState of type T that can be used to store state persistently */ def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] /** - * Function to create new or return existing single value state variable of given type - * with ttl. State values will not be returned past ttlDuration, and will be eventually removed - * from the state store. Any state update resets the ttl to current processing time plus - * ttlDuration. + * Function to create new or return existing single value state variable of given type with ttl. + * State values will not be returned past ttlDuration, and will be eventually removed from the + * state store. Any state update resets the ttl to current processing time plus ttlDuration. * * The user must ensure to call this function only within the `init()` method of the * StatefulProcessor. 
* - * @param stateName - name of the state variable - * @param valEncoder - SQL encoder for state variable - * @param ttlConfig - the ttl configuration (time to live duration etc.) - * @tparam T - type of state variable - * @return - instance of ValueState of type T that can be used to store state persistently + * @param stateName + * \- name of the state variable + * @param valEncoder + * \- SQL encoder for state variable + * @param ttlConfig + * \- the ttl configuration (time to live duration etc.) + * @tparam T + * \- type of state variable + * @return + * \- instance of ValueState of type T that can be used to store state persistently */ def getValueState[T]( stateName: String, @@ -62,30 +69,39 @@ private[sql] trait StatefulProcessorHandle extends Serializable { ttlConfig: TTLConfig): ValueState[T] /** - * Creates new or returns existing list state associated with stateName. - * The ListState persists values of type T. + * Creates new or returns existing list state associated with stateName. The ListState persists + * values of type T. * - * @param stateName - name of the state variable - * @param valEncoder - SQL encoder for state variable - * @tparam T - type of state variable - * @return - instance of ListState of type T that can be used to store state persistently + * @param stateName + * \- name of the state variable + * @param valEncoder + * \- SQL encoder for state variable + * @tparam T + * \- type of state variable + * @return + * \- instance of ListState of type T that can be used to store state persistently */ def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] /** - * Function to create new or return existing list state variable of given type - * with ttl. State values will not be returned past ttlDuration, and will be eventually removed - * from the state store. Any values in listState which have expired after ttlDuration will not - * be returned on get() and will be eventually removed from the state. + * Function to create new or return existing list state variable of given type with ttl. State + * values will not be returned past ttlDuration, and will be eventually removed from the state + * store. Any values in listState which have expired after ttlDuration will not be returned on + * get() and will be eventually removed from the state. * * The user must ensure to call this function only within the `init()` method of the * StatefulProcessor. * - * @param stateName - name of the state variable - * @param valEncoder - SQL encoder for state variable - * @param ttlConfig - the ttl configuration (time to live duration etc.) - * @tparam T - type of state variable - * @return - instance of ListState of type T that can be used to store state persistently + * @param stateName + * \- name of the state variable + * @param valEncoder + * \- SQL encoder for state variable + * @param ttlConfig + * \- the ttl configuration (time to live duration etc.) + * @tparam T + * \- type of state variable + * @return + * \- instance of ListState of type T that can be used to store state persistently */ def getListState[T]( stateName: String, @@ -93,15 +109,21 @@ private[sql] trait StatefulProcessorHandle extends Serializable { ttlConfig: TTLConfig): ListState[T] /** - * Creates new or returns existing map state associated with stateName. - * The MapState persists Key-Value pairs of type [K, V]. + * Creates new or returns existing map state associated with stateName. The MapState persists + * Key-Value pairs of type [K, V]. 
* - * @param stateName - name of the state variable - * @param userKeyEnc - spark sql encoder for the map key - * @param valEncoder - spark sql encoder for the map value - * @tparam K - type of key for map state variable - * @tparam V - type of value for map state variable - * @return - instance of MapState of type [K,V] that can be used to store state persistently + * @param stateName + * \- name of the state variable + * @param userKeyEnc + * \- spark sql encoder for the map key + * @param valEncoder + * \- spark sql encoder for the map value + * @tparam K + * \- type of key for map state variable + * @tparam V + * \- type of value for map state variable + * @return + * \- instance of MapState of type [K,V] that can be used to store state persistently */ def getMapState[K, V]( stateName: String, @@ -109,57 +131,68 @@ private[sql] trait StatefulProcessorHandle extends Serializable { valEncoder: Encoder[V]): MapState[K, V] /** - * Function to create new or return existing map state variable of given type - * with ttl. State values will not be returned past ttlDuration, and will be eventually removed - * from the state store. Any values in mapState which have expired after ttlDuration will not - * returned on get() and will be eventually removed from the state. + * Function to create new or return existing map state variable of given type with ttl. State + * values will not be returned past ttlDuration, and will be eventually removed from the state + * store. Any values in mapState which have expired after ttlDuration will not returned on get() + * and will be eventually removed from the state. * * The user must ensure to call this function only within the `init()` method of the * StatefulProcessor. * - * @param stateName - name of the state variable - * @param userKeyEnc - spark sql encoder for the map key - * @param valEncoder - SQL encoder for state variable - * @param ttlConfig - the ttl configuration (time to live duration etc.) - * @tparam K - type of key for map state variable - * @tparam V - type of value for map state variable - * @return - instance of MapState of type [K,V] that can be used to store state persistently + * @param stateName + * \- name of the state variable + * @param userKeyEnc + * \- spark sql encoder for the map key + * @param valEncoder + * \- SQL encoder for state variable + * @param ttlConfig + * \- the ttl configuration (time to live duration etc.) 
+ * @tparam K + * \- type of key for map state variable + * @tparam V + * \- type of value for map state variable + * @return + * \- instance of MapState of type [K,V] that can be used to store state persistently */ def getMapState[K, V]( - stateName: String, - userKeyEnc: Encoder[K], - valEncoder: Encoder[V], - ttlConfig: TTLConfig): MapState[K, V] + stateName: String, + userKeyEnc: Encoder[K], + valEncoder: Encoder[V], + ttlConfig: TTLConfig): MapState[K, V] /** Function to return queryInfo for currently running task */ def getQueryInfo(): QueryInfo /** - * Function to register a processing/event time based timer for given implicit grouping key - * and provided timestamp - * @param expiryTimestampMs - timer expiry timestamp in milliseconds + * Function to register a processing/event time based timer for given implicit grouping key and + * provided timestamp + * @param expiryTimestampMs + * \- timer expiry timestamp in milliseconds */ def registerTimer(expiryTimestampMs: Long): Unit /** - * Function to delete a processing/event time based timer for given implicit grouping key - * and provided timestamp - * @param expiryTimestampMs - timer expiry timestamp in milliseconds + * Function to delete a processing/event time based timer for given implicit grouping key and + * provided timestamp + * @param expiryTimestampMs + * \- timer expiry timestamp in milliseconds */ def deleteTimer(expiryTimestampMs: Long): Unit /** - * Function to list all the timers registered for given implicit grouping key - * Note: calling listTimers() within the `handleInputRows` method of the StatefulProcessor - * will return all the unprocessed registered timers, including the one being fired within the - * invocation of `handleInputRows`. - * @return - list of all the registered timers for given implicit grouping key + * Function to list all the timers registered for given implicit grouping key Note: calling + * listTimers() within the `handleInputRows` method of the StatefulProcessor will return all the + * unprocessed registered timers, including the one being fired within the invocation of + * `handleInputRows`. 
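The map-state and timer methods combine naturally: state variables are created in `init()`, while timers are registered per grouping key as rows arrive. A minimal sketch with illustrative names, assuming the API is accessible to user code:

```scala
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming.{MapState, StatefulProcessorHandle, TimerValues}

object MapStateAndTimers {
  // From init(): a per-key map of page -> click count.
  def createClickState(handle: StatefulProcessorHandle): MapState[String, Long] =
    handle.getMapState[String, Long]("clicksPerPage", Encoders.STRING, Encoders.scalaLong)

  // From handleInputRows(): schedule a processing-time cleanup one minute out,
  // skipping registration if an identical timer is already pending.
  def scheduleCleanup(handle: StatefulProcessorHandle, timerValues: TimerValues): Unit = {
    val expiryMs = timerValues.getCurrentProcessingTimeInMs() + 60 * 1000L
    if (!handle.listTimers().contains(expiryMs)) {
      handle.registerTimer(expiryMs)
    }
  }
}
```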
+ * @return + * \- list of all the registered timers for given implicit grouping key */ def listTimers(): Iterator[Long] /** * Function to delete and purge state variable if defined previously - * @param stateName - name of the state variable + * @param stateName + * \- name of the state variable */ def deleteIfExists(stateName: String): Unit } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala similarity index 63% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala rename to sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index fcb4bdcb327bc..a6684969ff1ec 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -20,19 +20,24 @@ package org.apache.spark.sql.streaming import java.util.UUID import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} +import com.fasterxml.jackson.databind.node.TreeTraversingParser import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule} -import org.json4s.{JObject, JString, JValue} +import org.json4s.{JObject, JString} +import org.json4s.JsonAST.JValue import org.json4s.JsonDSL.{jobject2assoc, pair2Assoc} import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.annotation.Evolving +import org.apache.spark.scheduler.SparkListenerEvent /** - * Interface for listening to events related to [[StreamingQuery StreamingQueries]]. + * Interface for listening to events related to + * [[org.apache.spark.sql.api.StreamingQuery StreamingQueries]]. + * * @note * The methods are not thread-safe as they may be called from different threads. * - * @since 3.5.0 + * @since 2.0.0 */ @Evolving abstract class StreamingQueryListener extends Serializable { @@ -42,12 +47,11 @@ abstract class StreamingQueryListener extends Serializable { /** * Called when a query is started. * @note - * This is called synchronously with - * [[org.apache.spark.sql.streaming.DataStreamWriter `DataStreamWriter.start()`]], that is, - * `onQueryStart` will be called on all listeners before `DataStreamWriter.start()` returns - * the corresponding [[StreamingQuery]]. Please don't block this method as it will block your - * query. - * @since 3.5.0 + * This is called synchronously with `DataStreamWriter.start()`, that is, `onQueryStart` will + * be called on all listeners before `DataStreamWriter.start()` returns the corresponding + * [[org.apache.spark.sql.api.StreamingQuery]]. Please don't block this method as it will + * block your query. + * @since 2.0.0 */ def onQueryStarted(event: QueryStartedEvent): Unit @@ -55,11 +59,12 @@ abstract class StreamingQueryListener extends Serializable { * Called when there is some status update (ingestion rate updated, etc.) * * @note - * This method is asynchronous. The status in [[StreamingQuery]] will always be latest no - * matter when this method is called. Therefore, the status of [[StreamingQuery]] may be - * changed before/when you process the event. E.g., you may find [[StreamingQuery]] is - * terminated when you are processing `QueryProgressEvent`. - * @since 3.5.0 + * This method is asynchronous. 
The status in [[org.apache.spark.sql.api.StreamingQuery]] will + * always be latest no matter when this method is called. Therefore, the status of + * [[org.apache.spark.sql.api.StreamingQuery]] may be changed before/when you process the + * event. E.g., you may find [[org.apache.spark.sql.api.StreamingQuery]] is terminated when + * you are processing `QueryProgressEvent`. + * @since 2.0.0 */ def onQueryProgress(event: QueryProgressEvent): Unit @@ -71,24 +76,68 @@ abstract class StreamingQueryListener extends Serializable { /** * Called when a query is stopped, with or without error. - * @since 3.5.0 + * @since 2.0.0 */ def onQueryTerminated(event: QueryTerminatedEvent): Unit } +/** + * Py4J allows a pure interface so this proxy is required. + */ +private[spark] trait PythonStreamingQueryListener { + import StreamingQueryListener._ + + def onQueryStarted(event: QueryStartedEvent): Unit + + def onQueryProgress(event: QueryProgressEvent): Unit + + def onQueryIdle(event: QueryIdleEvent): Unit + + def onQueryTerminated(event: QueryTerminatedEvent): Unit +} + +private[spark] class PythonStreamingQueryListenerWrapper(listener: PythonStreamingQueryListener) + extends StreamingQueryListener { + import StreamingQueryListener._ + + def onQueryStarted(event: QueryStartedEvent): Unit = listener.onQueryStarted(event) + + def onQueryProgress(event: QueryProgressEvent): Unit = listener.onQueryProgress(event) + + override def onQueryIdle(event: QueryIdleEvent): Unit = listener.onQueryIdle(event) + + def onQueryTerminated(event: QueryTerminatedEvent): Unit = listener.onQueryTerminated(event) +} + /** * Companion object of [[StreamingQueryListener]] that defines the listener events. - * @since 3.5.0 + * @since 2.0.0 */ @Evolving object StreamingQueryListener extends Serializable { + private val mapper = { + val ret = new ObjectMapper() with ClassTagExtensions + ret.registerModule(DefaultScalaModule) + ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + ret + } + + private case class EventParser(json: String) { + private val tree = mapper.readTree(json) + def getString(name: String): String = tree.get(name).asText() + def getUUID(name: String): UUID = UUID.fromString(getString(name)) + def getProgress(name: String): StreamingQueryProgress = { + val parser = new TreeTraversingParser(tree.get(name), mapper) + parser.readValueAs(classOf[StreamingQueryProgress]) + } + } /** * Base type of [[StreamingQueryListener]] events - * @since 3.5.0 + * @since 2.0.0 */ @Evolving - trait Event + trait Event extends SparkListenerEvent /** * Event representing the start of a query @@ -100,7 +149,7 @@ object StreamingQueryListener extends Serializable { * User-specified name of the query, null if not specified. * @param timestamp * The timestamp to start a query. 
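For reference, a listener written against this interface only needs the three abstract callbacks; `onQueryIdle` keeps its default no-op. A small sketch, not part of the patch, with an illustrative class name:

```scala
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

class LoggingListener extends StreamingQueryListener {
  override def onQueryStarted(event: QueryStartedEvent): Unit =
    println(s"started ${event.id} (name=${event.name})")

  override def onQueryProgress(event: QueryProgressEvent): Unit =
    println(s"progress ${event.progress.id}: ${event.progress.numInputRows} rows")

  override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
    println(s"terminated ${event.id}, exception=${event.exception}")
}

// Registered on an active session, e.g. spark.streams.addListener(new LoggingListener)
```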
- * @since 3.5.0 + * @since 2.1.0 */ @Evolving class QueryStartedEvent private[sql] ( @@ -122,25 +171,22 @@ object StreamingQueryListener extends Serializable { } private[spark] object QueryStartedEvent { - private val mapper = { - val ret = new ObjectMapper() with ClassTagExtensions - ret.registerModule(DefaultScalaModule) - ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) - ret - } - private[spark] def jsonString(event: QueryStartedEvent): String = - mapper.writeValueAsString(event) - - private[spark] def fromJson(json: String): QueryStartedEvent = - mapper.readValue[QueryStartedEvent](json) + private[spark] def fromJson(json: String): QueryStartedEvent = { + val parser = EventParser(json) + new QueryStartedEvent( + parser.getUUID("id"), + parser.getUUID("runId"), + parser.getString("name"), + parser.getString("timestamp")) + } } /** * Event representing any progress updates in a query. * @param progress * The query progress updates. - * @since 3.5.0 + * @since 2.1.0 */ @Evolving class QueryProgressEvent private[sql] (val progress: StreamingQueryProgress) @@ -153,18 +199,11 @@ object StreamingQueryListener extends Serializable { } private[spark] object QueryProgressEvent { - private val mapper = { - val ret = new ObjectMapper() with ClassTagExtensions - ret.registerModule(DefaultScalaModule) - ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) - ret - } - - private[spark] def jsonString(event: QueryProgressEvent): String = - mapper.writeValueAsString(event) - private[spark] def fromJson(json: String): QueryProgressEvent = - mapper.readValue[QueryProgressEvent](json) + private[spark] def fromJson(json: String): QueryProgressEvent = { + val parser = EventParser(json) + new QueryProgressEvent(parser.getProgress("progress")) + } } /** @@ -193,18 +232,14 @@ object StreamingQueryListener extends Serializable { } private[spark] object QueryIdleEvent { - private val mapper = { - val ret = new ObjectMapper() with ClassTagExtensions - ret.registerModule(DefaultScalaModule) - ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) - ret - } - private[spark] def jsonString(event: QueryTerminatedEvent): String = - mapper.writeValueAsString(event) - - private[spark] def fromJson(json: String): QueryTerminatedEvent = - mapper.readValue[QueryTerminatedEvent](json) + private[spark] def fromJson(json: String): QueryIdleEvent = { + val parser = EventParser(json) + new QueryIdleEvent( + parser.getUUID("id"), + parser.getUUID("runId"), + parser.getString("timestamp")) + } } /** @@ -221,7 +256,7 @@ object StreamingQueryListener extends Serializable { * The error class from the exception if the query was terminated with an exception which is a * part of error class framework. If the query was terminated without an exception, or the * exception is not a part of error class framework, it will be `None`.
- * @since 3.5.0 + * @since 2.1.0 */ @Evolving class QueryTerminatedEvent private[sql] ( @@ -247,17 +282,13 @@ object StreamingQueryListener extends Serializable { } private[spark] object QueryTerminatedEvent { - private val mapper = { - val ret = new ObjectMapper() with ClassTagExtensions - ret.registerModule(DefaultScalaModule) - ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) - ret + private[spark] def fromJson(json: String): QueryTerminatedEvent = { + val parser = EventParser(json) + new QueryTerminatedEvent( + parser.getUUID("id"), + parser.getUUID("runId"), + Option(parser.getString("exception")), + Option(parser.getString("errorClassOnException"))) } - - private[spark] def jsonString(event: QueryTerminatedEvent): String = - mapper.writeValueAsString(event) - - private[spark] def fromJson(json: String): QueryTerminatedEvent = - mapper.readValue[QueryTerminatedEvent](json) } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala similarity index 95% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala rename to sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala index cdda25876b250..c37cdd00c8866 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -36,7 +36,7 @@ import org.apache.spark.annotation.Evolving * True when the trigger is actively firing, false when waiting for the next trigger time. * Doesn't apply to ContinuousExecution where it is always false. * - * @since 3.5.0 + * @since 2.1.0 */ @Evolving class StreamingQueryStatus protected[sql] ( @@ -44,7 +44,6 @@ class StreamingQueryStatus protected[sql] ( val isDataAvailable: Boolean, val isTriggerActive: Boolean) extends Serializable { - // This is a copy of the same class in sql/core/.../streaming/StreamingQueryStatus.scala /** The compact JSON representation of this status. */ def json: String = compact(render(jsonValue)) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala index 576e09d5d7fe2..ce786aa943d89 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.streaming import java.time.Duration /** - * TTL Configuration for state variable. State values will not be returned past ttlDuration, - * and will be eventually removed from the state store. Any state update resets the ttl to - * current processing time plus ttlDuration. + * TTL Configuration for state variable. State values will not be returned past ttlDuration, and + * will be eventually removed from the state store. Any state update resets the ttl to current + * processing time plus ttlDuration. * - * @param ttlDuration time to live duration for state - * stored in the state variable. + * @param ttlDuration + * time to live duration for state stored in the state variable. 
*/ case class TTLConfig(ttlDuration: Duration) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/TimerValues.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TimerValues.scala index f0aef58228188..04c5f59428f7f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/TimerValues.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TimerValues.scala @@ -22,25 +22,29 @@ import java.io.Serializable import org.apache.spark.annotation.{Evolving, Experimental} /** - * Class used to provide access to timer values for processing and event time populated - * before method invocations using the arbitrary state API v2. + * Class used to provide access to timer values for processing and event time populated before + * method invocations using the arbitrary state API v2. */ @Experimental @Evolving private[sql] trait TimerValues extends Serializable { + /** * Get the current processing time as milliseconds in epoch time. - * @note This will return a constant value throughout the duration of a streaming query trigger, - * even if the trigger is re-executed. + * @note + * This will return a constant value throughout the duration of a streaming query trigger, + * even if the trigger is re-executed. */ def getCurrentProcessingTimeInMs(): Long /** * Get the current event time watermark as milliseconds in epoch time. * - * @note This can be called only when watermark is set before calling `transformWithState`. - * @note The watermark gets propagated at the end of each query. As a result, this method will - * return 0 (1970-01-01T00:00:00) for the first micro-batch. + * @note + * This can be called only when watermark is set before calling `transformWithState`. + * @note + * The watermark gets propagated at the end of each query. As a result, this method will + * return 0 (1970-01-01T00:00:00) for the first micro-batch. */ def getCurrentWatermarkInMs(): Long } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala index 8a2661e1a55b1..edb5f65365ab8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala @@ -24,8 +24,7 @@ import org.apache.spark.annotation.{Evolving, Experimental} @Experimental @Evolving /** - * Interface used for arbitrary stateful operations with the v2 API to capture - * single value state. + * Interface used for arbitrary stateful operations with the v2 API to capture single value state. */ private[sql] trait ValueState[S] extends Serializable { @@ -34,7 +33,8 @@ private[sql] trait ValueState[S] extends Serializable { /** * Get the state value if it exists - * @throws java.util.NoSuchElementException if the state does not exist + * @throws java.util.NoSuchElementException + * if the state does not exist */ @throws[NoSuchElementException] def get(): S @@ -45,7 +45,8 @@ private[sql] trait ValueState[S] extends Serializable { /** * Update the value of the state. 
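A short sketch of how `ValueState` and `TimerValues` are typically used from per-key processing logic. `exists()` is assumed here (it belongs to `ValueState` but is not shown in this hunk); the other calls use only the signatures documented above.

```scala
import org.apache.spark.sql.streaming.{TimerValues, ValueState}

object ValueStateUpdate {
  // Fold new rows into the stored count and write it back.
  def addToCount(count: ValueState[Long], rows: Iterator[String]): Long = {
    val current = if (count.exists()) count.get() else 0L // exists() assumed, see note above
    val updated = current + rows.size
    count.update(updated)
    updated
  }

  // Constant within a trigger, so every key in the micro-batch sees the same value.
  def batchTimeMs(timerValues: TimerValues): Long =
    timerValues.getCurrentProcessingTimeInMs()
}
```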
* - * @param newState the new value + * @param newState + * the new value */ def update(newState: S): Unit diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/progress.scala similarity index 99% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/progress.scala rename to sql/api/src/main/scala/org/apache/spark/sql/streaming/progress.scala index ebd13bc248f97..b7573cb280444 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -139,7 +139,7 @@ class StateOperatorProgress private[spark] ( * Information about operators in the query that store state. * @param sources * detailed statistics on data being read from each of the streaming sources. - * @since 3.5.0 + * @since 2.1.0 */ @Evolving class StreamingQueryProgress private[spark] ( @@ -195,7 +195,7 @@ class StreamingQueryProgress private[spark] ( } private[spark] object StreamingQueryProgress { - private val mapper = { + private[this] val mapper = { val ret = new ObjectMapper() with ClassTagExtensions ret.registerModule(DefaultScalaModule) ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) @@ -227,7 +227,7 @@ private[spark] object StreamingQueryProgress { * The rate at which data is arriving from this source. * @param processedRowsPerSecond * The rate at which data from this source is being processed by Spark. - * @since 3.5.0 + * @since 2.1.0 */ @Evolving class SourceProgress protected[spark] ( @@ -276,7 +276,7 @@ class SourceProgress protected[spark] ( * @param numOutputRows * Number of rows written to the sink or -1 for Continuous Mode (temporarily) or Sink V1 (until * decommissioned). - * @since 3.5.0 + * @since 2.1.0 */ @Evolving class SinkProgress protected[spark] ( diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 67f634f8379cd..9590fb23e16b1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.errors.DataTypeErrors * A non-concrete data type, reserved for internal uses. */ private[sql] abstract class AbstractDataType { + /** * The default concrete type to use if we want to cast a null literal into this type. */ @@ -47,7 +48,6 @@ private[sql] abstract class AbstractDataType { private[sql] def simpleString: String } - /** * A collection of types that can be used to specify type constraints. The sequence also specifies * precedence: an earlier type takes precedence over a latter type. @@ -59,7 +59,7 @@ private[sql] abstract class AbstractDataType { * This means that we prefer StringType over BinaryType if it is possible to cast to StringType. */ private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) - extends AbstractDataType { + extends AbstractDataType { require(types.nonEmpty, s"TypeCollection ($types) cannot be empty") @@ -73,22 +73,20 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) } } - private[sql] object TypeCollection { /** * Types that include numeric types and ANSI interval types. 
*/ - val NumericAndAnsiInterval = TypeCollection( - NumericType, - DayTimeIntervalType, - YearMonthIntervalType) + val NumericAndAnsiInterval = + TypeCollection(NumericType, DayTimeIntervalType, YearMonthIntervalType) /** - * Types that include numeric and ANSI interval types, and additionally the legacy interval type. - * They are only used in unary_minus, unary_positive, add and subtract operations. + * Types that include numeric and ANSI interval types, and additionally the legacy interval + * type. They are only used in unary_minus, unary_positive, add and subtract operations. */ - val NumericAndInterval = new TypeCollection(NumericAndAnsiInterval.types :+ CalendarIntervalType) + val NumericAndInterval = new TypeCollection( + NumericAndAnsiInterval.types :+ CalendarIntervalType) def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) @@ -98,7 +96,6 @@ private[sql] object TypeCollection { } } - /** * An `AbstractDataType` that matches any concrete data types. */ @@ -114,15 +111,14 @@ protected[sql] object AnyDataType extends AbstractDataType with Serializable { override private[sql] def acceptsType(other: DataType): Boolean = true } - /** - * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. + * An internal type used to represent everything that is not null, UDTs, arrays, structs, and + * maps. */ protected[sql] abstract class AtomicType extends DataType object AtomicType - /** * Numeric data types. * @@ -131,7 +127,6 @@ object AtomicType @Stable abstract class NumericType extends AtomicType - private[spark] object NumericType extends AbstractDataType { override private[spark] def defaultConcreteType: DataType = DoubleType @@ -141,22 +136,19 @@ private[spark] object NumericType extends AbstractDataType { other.isInstanceOf[NumericType] } - private[sql] object IntegralType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = IntegerType override private[sql] def simpleString: String = "integral" - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType] + override private[sql] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[IntegralType] } - private[sql] abstract class IntegralType extends NumericType - private[sql] object FractionalType - private[sql] abstract class FractionalType extends NumericType private[sql] object AnyTimestampType extends AbstractDataType with Serializable { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index e5af472d90e25..fc32248b4baf3 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -29,12 +29,14 @@ import org.apache.spark.sql.catalyst.util.StringConcat */ @Stable object ArrayType extends AbstractDataType { + /** * Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. 
*/ def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true) - override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) + override private[sql] def defaultConcreteType: DataType = + ArrayType(NullType, containsNull = true) override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[ArrayType] @@ -44,18 +46,19 @@ object ArrayType extends AbstractDataType { } /** - * The data type for collections of multiple values. - * Internally these are represented as columns that contain a ``scala.collection.Seq``. + * The data type for collections of multiple values. Internally these are represented as columns + * that contain a ``scala.collection.Seq``. * * Please use `DataTypes.createArrayType()` to create a specific instance. * - * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and - * `containsNull: Boolean`. - * The field of `elementType` is used to specify the type of array elements. - * The field of `containsNull` is used to specify if the array can have `null` values. + * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and `containsNull: + * Boolean`. The field of `elementType` is used to specify the type of array elements. The field + * of `containsNull` is used to specify if the array can have `null` values. * - * @param elementType The data type of values. - * @param containsNull Indicates if the array can have `null` values + * @param elementType + * The data type of values. + * @param containsNull + * Indicates if the array can have `null` values * * @since 1.3.0 */ @@ -82,8 +85,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT ("containsNull" -> containsNull) /** - * The default size of a value of the ArrayType is the default size of the element type. - * We assume that there is only 1 element on average in an array. See SPARK-18853. + * The default size of a value of the ArrayType is the default size of the element type. We + * assume that there is only 1 element on average in an array. See SPARK-18853. */ override def defaultSize: Int = 1 * elementType.defaultSize @@ -97,8 +100,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT ArrayType(elementType.asNullable, containsNull = true) /** - * Returns the same data type but set all nullability fields are true - * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + * Returns the same data type but set all nullability fields are true (`StructField.nullable`, + * `ArrayType.containsNull`, and `MapType.valueContainsNull`). * * @since 4.0.0 */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index c280f66f943aa..20bfd9bf5684f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.types import org.apache.spark.annotation.Stable /** - * The data type representing `Array[Byte]` values. - * Please use the singleton `DataTypes.BinaryType`. + * The data type representing `Array[Byte]` values. Please use the singleton + * `DataTypes.BinaryType`. */ @Stable -class BinaryType private() extends AtomicType { +class BinaryType private () extends AtomicType { + /** * The default size of a value of the BinaryType is 100 bytes. 
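To illustrate the `ArrayType` contract spelled out above, a small self-contained example (not part of the patch):

```scala
import org.apache.spark.sql.types.{ArrayType, IntegerType}

object ArrayTypeExample {
  // ArrayType(elementType) defaults containsNull to true; pass it explicitly to forbid nulls.
  val nullableInts: ArrayType = ArrayType(IntegerType)
  val nonNullInts: ArrayType = ArrayType(IntegerType, containsNull = false)

  // Per SPARK-18853 the default size assumes one element on average.
  assert(nullableInts.defaultSize == IntegerType.defaultSize)
}
```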
*/ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/BooleanType.scala index 836c41a996ac4..090c56eaf9af7 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/BooleanType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/BooleanType.scala @@ -25,7 +25,8 @@ import org.apache.spark.annotation.Stable * @since 1.3.0 */ @Stable -class BooleanType private() extends AtomicType { +class BooleanType private () extends AtomicType { + /** * The default size of a value of the BooleanType is 1 byte. */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ByteType.scala index 546ac02f2639a..4a27a00dacb8a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/ByteType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -25,7 +25,8 @@ import org.apache.spark.annotation.Stable * @since 1.3.0 */ @Stable -class ByteType private() extends IntegralType { +class ByteType private () extends IntegralType { + /** * The default size of a value of the ByteType is 1 byte. */ @@ -36,7 +37,6 @@ class ByteType private() extends IntegralType { private[spark] override def asNullable: ByteType = this } - /** * @since 1.3.0 */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala index d506a1521e183..f6b6256db0417 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala @@ -21,19 +21,19 @@ import org.apache.spark.annotation.Stable /** * The data type representing calendar intervals. The calendar interval is stored internally in - * three components: - * an integer value representing the number of `months` in this interval, - * an integer value representing the number of `days` in this interval, - * a long value representing the number of `microseconds` in this interval. + * three components: an integer value representing the number of `months` in this interval, an + * integer value representing the number of `days` in this interval, a long value representing the + * number of `microseconds` in this interval. * * Please use the singleton `DataTypes.CalendarIntervalType` to refer the type. * - * @note Calendar intervals are not comparable. + * @note + * Calendar intervals are not comparable. * * @since 1.5.0 */ @Stable -class CalendarIntervalType private() extends DataType { +class CalendarIntervalType private () extends DataType { override def defaultSize: Int = 16 diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index 277d5c9458d6f..008c9cd07076c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -49,6 +49,7 @@ import org.apache.spark.util.SparkClassUtils @JsonSerialize(using = classOf[DataTypeJsonSerializer]) @JsonDeserialize(using = classOf[DataTypeJsonDeserializer]) abstract class DataType extends AbstractDataType { + /** * The default size of a value of this data type, used internally for size estimation. */ @@ -57,7 +58,9 @@ abstract class DataType extends AbstractDataType { /** Name of the type used in JSON serialization. 
*/ def typeName: String = { this.getClass.getSimpleName - .stripSuffix("$").stripSuffix("Type").stripSuffix("UDT") + .stripSuffix("$") + .stripSuffix("Type") + .stripSuffix("UDT") .toLowerCase(Locale.ROOT) } @@ -92,8 +95,8 @@ abstract class DataType extends AbstractDataType { } /** - * Returns the same data type but set all nullability fields are true - * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + * Returns the same data type but set all nullability fields are true (`StructField.nullable`, + * `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def asNullable: DataType @@ -107,7 +110,6 @@ abstract class DataType extends AbstractDataType { override private[sql] def acceptsType(other: DataType): Boolean = sameType(other) } - /** * @since 1.3.0 */ @@ -128,14 +130,18 @@ object DataType { } /** - * Parses data type from a string with schema. It calls `parser` for `schema`. - * If it fails, calls `fallbackParser`. If the fallback function fails too, combines error message - * from `parser` and `fallbackParser`. + * Parses data type from a string with schema. It calls `parser` for `schema`. If it fails, + * calls `fallbackParser`. If the fallback function fails too, combines error message from + * `parser` and `fallbackParser`. * - * @param schema The schema string to parse by `parser` or `fallbackParser`. - * @param parser The function that should be invoke firstly. - * @param fallbackParser The function that is called when `parser` fails. - * @return The data type parsed from the `schema` schema. + * @param schema + * The schema string to parse by `parser` or `fallbackParser`. + * @param parser + * The function that should be invoke firstly. + * @param fallbackParser + * The function that is called when `parser` fails. + * @return + * The data type parsed from the `schema` schema. 
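As a quick illustration of the JSON entry points this hunk touches (not part of the patch), a schema round-trips cleanly through its JSON form:

```scala
import org.apache.spark.sql.types.{DataType, StructType}

object DataTypeJsonRoundTrip {
  val schema: StructType = StructType.fromDDL("id INT, tags ARRAY<STRING>")
  // DataType.json produces the canonical representation that DataType.fromJson parses back.
  val restored: DataType = DataType.fromJson(schema.json)
  assert(restored == schema)
}
```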
*/ def parseTypeWithFallback( schema: String, @@ -161,8 +167,20 @@ object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) private val otherTypes = { - Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType, - DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType, + Seq( + NullType, + DateType, + TimestampType, + BinaryType, + IntegerType, + BooleanType, + LongType, + DoubleType, + FloatType, + ShortType, + ByteType, + StringType, + CalendarIntervalType, DayTimeIntervalType(DAY), DayTimeIntervalType(DAY, HOUR), DayTimeIntervalType(DAY, MINUTE), @@ -178,7 +196,8 @@ object DataType { YearMonthIntervalType(YEAR, MONTH), TimestampNTZType, VariantType) - .map(t => t.typeName -> t).toMap + .map(t => t.typeName -> t) + .toMap } /** Given the string representation of a type, return its DataType */ @@ -191,11 +210,12 @@ object DataType { // For backwards compatibility, previously the type name of NullType is "null" case "null" => NullType case "timestamp_ltz" => TimestampType - case other => otherTypes.getOrElse( - other, - throw new SparkIllegalArgumentException( - errorClass = "INVALID_JSON_DATA_TYPE", - messageParameters = Map("invalidType" -> name))) + case other => + otherTypes.getOrElse( + other, + throw new SparkIllegalArgumentException( + errorClass = "INVALID_JSON_DATA_TYPE", + messageParameters = Map("invalidType" -> name))) } } @@ -220,56 +240,55 @@ object DataType { } case JSortedObject( - ("containsNull", JBool(n)), - ("elementType", t: JValue), - ("type", JString("array"))) => + ("containsNull", JBool(n)), + ("elementType", t: JValue), + ("type", JString("array"))) => assertValidTypeForCollations(fieldPath, "array", collationsMap) val elementType = parseDataType(t, appendFieldToPath(fieldPath, "element"), collationsMap) ArrayType(elementType, n) case JSortedObject( - ("keyType", k: JValue), - ("type", JString("map")), - ("valueContainsNull", JBool(n)), - ("valueType", v: JValue)) => + ("keyType", k: JValue), + ("type", JString("map")), + ("valueContainsNull", JBool(n)), + ("valueType", v: JValue)) => assertValidTypeForCollations(fieldPath, "map", collationsMap) val keyType = parseDataType(k, appendFieldToPath(fieldPath, "key"), collationsMap) val valueType = parseDataType(v, appendFieldToPath(fieldPath, "value"), collationsMap) MapType(keyType, valueType, n) - case JSortedObject( - ("fields", JArray(fields)), - ("type", JString("struct"))) => + case JSortedObject(("fields", JArray(fields)), ("type", JString("struct"))) => assertValidTypeForCollations(fieldPath, "struct", collationsMap) StructType(fields.map(parseStructField)) // Scala/Java UDT case JSortedObject( - ("class", JString(udtClass)), - ("pyClass", _), - ("sqlType", _), - ("type", JString("udt"))) => + ("class", JString(udtClass)), + ("pyClass", _), + ("sqlType", _), + ("type", JString("udt"))) => SparkClassUtils.classForName[UserDefinedType[_]](udtClass).getConstructor().newInstance() // Python UDT case JSortedObject( - ("pyClass", JString(pyClass)), - ("serializedClass", JString(serialized)), - ("sqlType", v: JValue), - ("type", JString("udt"))) => - new PythonUserDefinedType(parseDataType(v), pyClass, serialized) - - case other => throw new SparkIllegalArgumentException( - errorClass = "INVALID_JSON_DATA_TYPE", - messageParameters = Map("invalidType" -> compact(render(other)))) + ("pyClass", JString(pyClass)), + ("serializedClass", JString(serialized)), + ("sqlType", v: JValue), + ("type", JString("udt"))) => + new 
PythonUserDefinedType(parseDataType(v), pyClass, serialized) + + case other => + throw new SparkIllegalArgumentException( + errorClass = "INVALID_JSON_DATA_TYPE", + messageParameters = Map("invalidType" -> compact(render(other)))) } private def parseStructField(json: JValue): StructField = json match { case JSortedObject( - ("metadata", JObject(metadataFields)), - ("name", JString(name)), - ("nullable", JBool(nullable)), - ("type", dataType: JValue)) => + ("metadata", JObject(metadataFields)), + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => val collationsMap = getCollationsMap(metadataFields) val metadataWithoutCollations = JObject(metadataFields.filterNot(_._1 == COLLATIONS_METADATA_KEY)) @@ -280,18 +299,17 @@ object DataType { Metadata.fromJObject(metadataWithoutCollations)) // Support reading schema when 'metadata' is missing. case JSortedObject( - ("name", JString(name)), - ("nullable", JBool(nullable)), - ("type", dataType: JValue)) => + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => StructField(name, parseDataType(dataType), nullable) // Support reading schema when 'nullable' is missing. - case JSortedObject( - ("name", JString(name)), - ("type", dataType: JValue)) => + case JSortedObject(("name", JString(name)), ("type", dataType: JValue)) => StructField(name, parseDataType(dataType)) - case other => throw new SparkIllegalArgumentException( - errorClass = "INVALID_JSON_DATA_TYPE", - messageParameters = Map("invalidType" -> compact(render(other)))) + case other => + throw new SparkIllegalArgumentException( + errorClass = "INVALID_JSON_DATA_TYPE", + messageParameters = Map("invalidType" -> compact(render(other)))) } private def assertValidTypeForCollations( @@ -319,13 +337,12 @@ object DataType { val collationsJsonOpt = metadataFields.find(_._1 == COLLATIONS_METADATA_KEY).map(_._2) collationsJsonOpt match { case Some(JObject(fields)) => - fields.collect { - case (fieldPath, JString(collation)) => - collation.split("\\.", 2) match { - case Array(provider: String, collationName: String) => - CollationFactory.assertValidProvider(provider) - fieldPath -> collationName - } + fields.collect { case (fieldPath, JString(collation)) => + collation.split("\\.", 2) match { + case Array(provider: String, collationName: String) => + CollationFactory.assertValidProvider(provider) + fieldPath -> collationName + } }.toMap case _ => Map.empty @@ -356,15 +373,15 @@ object DataType { * Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType. * * Compatible nullability is defined as follows: - * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to` - * if and only if `to.containsNull` is true, or both of `from.containsNull` and - * `to.containsNull` are false. - * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to` - * if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and - * `to.valueContainsNull` are false. - * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to` - * if and only if for all every pair of fields, `to.nullable` is true, or both - * of `fromField.nullable` and `toField.nullable` are false. + * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to` if and + * only if `to.containsNull` is true, or both of `from.containsNull` and `to.containsNull` + * are false. 
+ * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to` if and + * only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and + * `to.valueContainsNull` are false. + * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to` if and + * only if for all every pair of fields, `to.nullable` is true, or both of + * `fromField.nullable` and `toField.nullable` are false. */ private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = { equalsIgnoreCompatibleNullability(from, to, ignoreName = false) @@ -375,15 +392,15 @@ object DataType { * also the field name. It compares based on the position. * * Compatible nullability is defined as follows: - * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to` - * if and only if `to.containsNull` is true, or both of `from.containsNull` and - * `to.containsNull` are false. - * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to` - * if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and - * `to.valueContainsNull` are false. - * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to` - * if and only if for all every pair of fields, `to.nullable` is true, or both - * of `fromField.nullable` and `toField.nullable` are false. + * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to` if and + * only if `to.containsNull` is true, or both of `from.containsNull` and `to.containsNull` + * are false. + * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to` if and + * only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and + * `to.valueContainsNull` are false. + * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to` if and + * only if for all every pair of fields, `to.nullable` is true, or both of + * `fromField.nullable` and `toField.nullable` are false. */ private[sql] def equalsIgnoreNameAndCompatibleNullability( from: DataType, @@ -401,16 +418,16 @@ object DataType { case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => (tn || !fn) && - equalsIgnoreCompatibleNullability(fromKey, toKey, ignoreName) && - equalsIgnoreCompatibleNullability(fromValue, toValue, ignoreName) + equalsIgnoreCompatibleNullability(fromKey, toKey, ignoreName) && + equalsIgnoreCompatibleNullability(fromValue, toValue, ignoreName) case (StructType(fromFields), StructType(toFields)) => fromFields.length == toFields.length && - fromFields.zip(toFields).forall { case (fromField, toField) => - (ignoreName || fromField.name == toField.name) && - (toField.nullable || !fromField.nullable) && - equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType, ignoreName) - } + fromFields.zip(toFields).forall { case (fromField, toField) => + (ignoreName || fromField.name == toField.name) && + (toField.nullable || !fromField.nullable) && + equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType, ignoreName) + } case (fromDataType, toDataType) => fromDataType == toDataType } @@ -420,42 +437,42 @@ object DataType { * Check if `from` is equal to `to` type except for collations, which are checked to be * compatible so that data of type `from` can be interpreted as of type `to`. 
*/ - private[sql] def equalsIgnoreCompatibleCollation( - from: DataType, - to: DataType): Boolean = { + private[sql] def equalsIgnoreCompatibleCollation(from: DataType, to: DataType): Boolean = { (from, to) match { // String types with possibly different collations are compatible. case (_: StringType, _: StringType) => true case (ArrayType(fromElement, fromContainsNull), ArrayType(toElement, toContainsNull)) => (fromContainsNull == toContainsNull) && - equalsIgnoreCompatibleCollation(fromElement, toElement) + equalsIgnoreCompatibleCollation(fromElement, toElement) - case (MapType(fromKey, fromValue, fromContainsNull), - MapType(toKey, toValue, toContainsNull)) => + case ( + MapType(fromKey, fromValue, fromContainsNull), + MapType(toKey, toValue, toContainsNull)) => fromContainsNull == toContainsNull && - // Map keys cannot change collation. - fromKey == toKey && - equalsIgnoreCompatibleCollation(fromValue, toValue) + // Map keys cannot change collation. + fromKey == toKey && + equalsIgnoreCompatibleCollation(fromValue, toValue) case (StructType(fromFields), StructType(toFields)) => fromFields.length == toFields.length && - fromFields.zip(toFields).forall { case (fromField, toField) => - fromField.name == toField.name && - fromField.nullable == toField.nullable && - fromField.metadata == toField.metadata && - equalsIgnoreCompatibleCollation(fromField.dataType, toField.dataType) - } + fromFields.zip(toFields).forall { case (fromField, toField) => + fromField.name == toField.name && + fromField.nullable == toField.nullable && + fromField.metadata == toField.metadata && + equalsIgnoreCompatibleCollation(fromField.dataType, toField.dataType) + } case (fromDataType, toDataType) => fromDataType == toDataType } } /** - * Returns true if the two data types share the same "shape", i.e. the types - * are the same, but the field names don't need to be the same. + * Returns true if the two data types share the same "shape", i.e. the types are the same, but + * the field names don't need to be the same. 
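A small example (not part of the patch) of the "same shape" comparison described here: the field names differ, so the schemas are not equal, but they are structurally equal.

```scala
import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructType}

object StructuralEquality {
  private val a = new StructType().add("x", IntegerType).add("y", StringType)
  private val b = new StructType().add("col1", IntegerType).add("col2", StringType)

  assert(DataType.equalsStructurally(a, b, ignoreNullability = true))
  assert(a != b) // names still matter for ordinary equality
}
```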
* - * @param ignoreNullability whether to ignore nullability when comparing the types + * @param ignoreNullability + * whether to ignore nullability when comparing the types */ def equalsStructurally( from: DataType, @@ -464,20 +481,21 @@ object DataType { (from, to) match { case (left: ArrayType, right: ArrayType) => equalsStructurally(left.elementType, right.elementType, ignoreNullability) && - (ignoreNullability || left.containsNull == right.containsNull) + (ignoreNullability || left.containsNull == right.containsNull) case (left: MapType, right: MapType) => equalsStructurally(left.keyType, right.keyType, ignoreNullability) && - equalsStructurally(left.valueType, right.valueType, ignoreNullability) && - (ignoreNullability || left.valueContainsNull == right.valueContainsNull) + equalsStructurally(left.valueType, right.valueType, ignoreNullability) && + (ignoreNullability || left.valueContainsNull == right.valueContainsNull) case (StructType(fromFields), StructType(toFields)) => fromFields.length == toFields.length && - fromFields.zip(toFields) - .forall { case (l, r) => - equalsStructurally(l.dataType, r.dataType, ignoreNullability) && - (ignoreNullability || l.nullable == r.nullable) - } + fromFields + .zip(toFields) + .forall { case (l, r) => + equalsStructurally(l.dataType, r.dataType, ignoreNullability) && + (ignoreNullability || l.nullable == r.nullable) + } case (fromDataType, toDataType) => fromDataType == toDataType } @@ -496,14 +514,15 @@ object DataType { case (left: MapType, right: MapType) => equalsStructurallyByName(left.keyType, right.keyType, resolver) && - equalsStructurallyByName(left.valueType, right.valueType, resolver) + equalsStructurallyByName(left.valueType, right.valueType, resolver) case (StructType(fromFields), StructType(toFields)) => fromFields.length == toFields.length && - fromFields.zip(toFields) - .forall { case (l, r) => - resolver(l.name, r.name) && equalsStructurallyByName(l.dataType, r.dataType, resolver) - } + fromFields + .zip(toFields) + .forall { case (l, r) => + resolver(l.name, r.name) && equalsStructurallyByName(l.dataType, r.dataType, resolver) + } case _ => true } @@ -518,12 +537,12 @@ object DataType { equalsIgnoreNullability(leftElementType, rightElementType) case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => equalsIgnoreNullability(leftKeyType, rightKeyType) && - equalsIgnoreNullability(leftValueType, rightValueType) + equalsIgnoreNullability(leftValueType, rightValueType) case (StructType(leftFields), StructType(rightFields)) => leftFields.length == rightFields.length && - leftFields.zip(rightFields).forall { case (l, r) => - l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType) - } + leftFields.zip(rightFields).forall { case (l, r) => + l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType) + } case (l, r) => l == r } } @@ -539,14 +558,14 @@ object DataType { case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => equalsIgnoreCaseAndNullability(fromKey, toKey) && - equalsIgnoreCaseAndNullability(fromValue, toValue) + equalsIgnoreCaseAndNullability(fromValue, toValue) case (StructType(fromFields), StructType(toFields)) => fromFields.length == toFields.length && - fromFields.zip(toFields).forall { case (l, r) => - l.name.equalsIgnoreCase(r.name) && - equalsIgnoreCaseAndNullability(l.dataType, r.dataType) - } + fromFields.zip(toFields).forall { case (l, r) => + l.name.equalsIgnoreCase(r.name) && + equalsIgnoreCaseAndNullability(l.dataType, r.dataType) + } case 
(fromDataType, toDataType) => fromDataType == toDataType } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DateType.scala index d37ebbcdad727..402c4c0d95298 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DateType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DateType.scala @@ -20,14 +20,15 @@ package org.apache.spark.sql.types import org.apache.spark.annotation.Stable /** - * The date type represents a valid date in the proleptic Gregorian calendar. - * Valid range is [0001-01-01, 9999-12-31]. + * The date type represents a valid date in the proleptic Gregorian calendar. Valid range is + * [0001-01-01, 9999-12-31]. * * Please use the singleton `DataTypes.DateType` to refer the type. * @since 1.3.0 */ @Stable -class DateType private() extends DatetimeType { +class DateType private () extends DatetimeType { + /** * The default size of a value of the DateType is 4 bytes. */ @@ -37,10 +38,10 @@ class DateType private() extends DatetimeType { } /** - * The companion case object and the DateType class is separated so the companion object - * also subclasses the class. Otherwise, the companion object would be of type "DateType$" - * in byte code. The DateType class is defined with a private constructor so its companion - * object is the only possible instantiation. + * The companion case object and the DateType class is separated so the companion object also + * subclasses the class. Otherwise, the companion object would be of type "DateType$" in byte + * code. The DateType class is defined with a private constructor so its companion object is the + * only possible instantiation. * * @since 1.3.0 */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala index a1d014fa51f1c..90d6d7c29a6ba 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala @@ -22,8 +22,8 @@ import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.types.DayTimeIntervalType.fieldToString /** - * The type represents day-time intervals of the SQL standard. A day-time interval is made up - * of a contiguous subset of the following fields: + * The type represents day-time intervals of the SQL standard. A day-time interval is made up of a + * contiguous subset of the following fields: * - SECOND, seconds within minutes and possibly fractions of a second [0..59.999999], * - MINUTE, minutes within hours [0..59], * - HOUR, hours within days [0..23], @@ -31,18 +31,21 @@ import org.apache.spark.sql.types.DayTimeIntervalType.fieldToString * * `DayTimeIntervalType` represents positive as well as negative day-time intervals. * - * @param startField The leftmost field which the type comprises of. Valid values: - * 0 (DAY), 1 (HOUR), 2 (MINUTE), 3 (SECOND). - * @param endField The rightmost field which the type comprises of. Valid values: - * 0 (DAY), 1 (HOUR), 2 (MINUTE), 3 (SECOND). + * @param startField + * The leftmost field which the type comprises of. Valid values: 0 (DAY), 1 (HOUR), 2 (MINUTE), + * 3 (SECOND). + * @param endField + * The rightmost field which the type comprises of. Valid values: 0 (DAY), 1 (HOUR), 2 (MINUTE), + * 3 (SECOND). 
* * @since 3.2.0 */ @Unstable case class DayTimeIntervalType(startField: Byte, endField: Byte) extends AnsiIntervalType { + /** - * The day-time interval type has constant precision. A value of the type always occupies 8 bytes. - * The DAY field is constrained by the upper bound 106751991 to fit to `Long`. + * The day-time interval type has constant precision. A value of the type always occupies 8 + * bytes. The DAY field is constrained by the upper bound 106751991 to fit to `Long`. */ override def defaultSize: Int = 8 diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 6de8570b1422f..bd94c386ab533 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -31,9 +31,9 @@ import org.apache.spark.unsafe.types.UTF8String * A mutable implementation of BigDecimal that can hold a Long if values are small enough. * * The semantics of the fields are as follows: - * - _precision and _scale represent the SQL precision and scale we are looking for - * - If decimalVal is set, it represents the whole decimal value - * - Otherwise, the decimal value is longVal / (10 ** _scale) + * - _precision and _scale represent the SQL precision and scale we are looking for + * - If decimalVal is set, it represents the whole decimal value + * - Otherwise, the decimal value is longVal / (10 ** _scale) * * Note, for values between -1.0 and 1.0, precision digits are only counted after dot. */ @@ -88,22 +88,22 @@ final class Decimal extends Ordered[Decimal] with Serializable { } /** - * Set this Decimal to the given unscaled Long, with a given precision and scale, - * and return it, or return null if it cannot be set due to overflow. + * Set this Decimal to the given unscaled Long, with a given precision and scale, and return it, + * or return null if it cannot be set due to overflow. */ def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = { DecimalType.checkNegativeScale(scale) if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) { // We can't represent this compactly as a long without risking overflow if (precision < 19) { - return null // Requested precision is too low to represent this value + return null // Requested precision is too low to represent this value } this.decimalVal = BigDecimal(unscaled, scale) this.longVal = 0L } else { val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) if (unscaled <= -p || unscaled >= p) { - return null // Requested precision is too low to represent this value + return null // Requested precision is too low to represent this value } this.decimalVal = null this.longVal = unscaled @@ -126,7 +126,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { "roundedValue" -> decimalVal.toString, "originalValue" -> decimal.toString, "precision" -> precision.toString, - "scale" -> scale.toString), Array.empty) + "scale" -> scale.toString), + Array.empty) } this.longVal = 0L this._precision = precision @@ -160,8 +161,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { } /** - * If the value is not in the range of long, convert it to BigDecimal and - * the precision and scale are based on the converted value. + * If the value is not in the range of long, convert it to BigDecimal and the precision and + * scale are based on the converted value. 
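To make the Long-versus-BigDecimal representation concrete, a brief example (not part of the patch) built on the companion constructors; the null-on-overflow behaviour matches `setOrNull` above.

```scala
import org.apache.spark.sql.types.Decimal

object DecimalRepresentation {
  // 123.45 stored compactly as unscaled 12345 with precision 5, scale 2.
  val price: Decimal = Decimal(12345L, 5, 2)
  assert(price.toBigDecimal == BigDecimal("123.45"))

  // Requested precision 3 cannot hold unscaled 12345, so setOrNull signals overflow with null.
  assert(new Decimal().setOrNull(12345L, 3, 2) == null)
}
```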
* * This code avoids BigDecimal object allocation as possible to improve runtime efficiency */ @@ -262,37 +263,47 @@ final class Decimal extends Ordered[Decimal] with Serializable { def toByte: Byte = toLong.toByte /** - * @return the Byte value that is equal to the rounded decimal. - * @throws ArithmeticException if the decimal is too big to fit in Byte type. + * @return + * the Byte value that is equal to the rounded decimal. + * @throws ArithmeticException + * if the decimal is too big to fit in Byte type. */ private[sql] def roundToByte(): Byte = - roundToNumeric[Byte](ByteType, Byte.MaxValue, Byte.MinValue) (_.toByte) (_.toByte) + roundToNumeric[Byte](ByteType, Byte.MaxValue, Byte.MinValue)(_.toByte)(_.toByte) /** - * @return the Short value that is equal to the rounded decimal. - * @throws ArithmeticException if the decimal is too big to fit in Short type. + * @return + * the Short value that is equal to the rounded decimal. + * @throws ArithmeticException + * if the decimal is too big to fit in Short type. */ private[sql] def roundToShort(): Short = - roundToNumeric[Short](ShortType, Short.MaxValue, Short.MinValue) (_.toShort) (_.toShort) + roundToNumeric[Short](ShortType, Short.MaxValue, Short.MinValue)(_.toShort)(_.toShort) /** - * @return the Int value that is equal to the rounded decimal. - * @throws ArithmeticException if the decimal too big to fit in Int type. + * @return + * the Int value that is equal to the rounded decimal. + * @throws ArithmeticException + * if the decimal too big to fit in Int type. */ private[sql] def roundToInt(): Int = - roundToNumeric[Int](IntegerType, Int.MaxValue, Int.MinValue) (_.toInt) (_.toInt) + roundToNumeric[Int](IntegerType, Int.MaxValue, Int.MinValue)(_.toInt)(_.toInt) private def toSqlValue: String = this.toString + "BD" - private def roundToNumeric[T <: AnyVal](integralType: IntegralType, maxValue: Int, minValue: Int) - (f1: Long => T) (f2: Double => T): T = { + private def roundToNumeric[T <: AnyVal]( + integralType: IntegralType, + maxValue: Int, + minValue: Int)(f1: Long => T)(f2: Double => T): T = { if (decimalVal.eq(null)) { val numericVal = f1(actualLongVal) if (actualLongVal == numericVal) { numericVal } else { throw DataTypeErrors.castingCauseOverflowError( - toSqlValue, DecimalType(this.precision, this.scale), integralType) + toSqlValue, + DecimalType(this.precision, this.scale), + integralType) } } else { val doubleVal = decimalVal.toDouble @@ -300,14 +311,18 @@ final class Decimal extends Ordered[Decimal] with Serializable { f2(doubleVal) } else { throw DataTypeErrors.castingCauseOverflowError( - toSqlValue, DecimalType(this.precision, this.scale), integralType) + toSqlValue, + DecimalType(this.precision, this.scale), + integralType) } } } /** - * @return the Long value that is equal to the rounded decimal. - * @throws ArithmeticException if the decimal too big to fit in Long type. + * @return + * the Long value that is equal to the rounded decimal. + * @throws ArithmeticException + * if the decimal too big to fit in Long type. 
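The overflow behaviour documented in this hunk comes in two flavours: `setOrNull` hands back null when the requested precision cannot hold the unscaled value, while the internal `roundToByte`/`roundToShort`/`roundToInt` helpers throw. A small sketch of the former, assuming the public no-arg `Decimal` constructor:

{{{
import org.apache.spark.sql.types.Decimal

// 12345 unscaled at scale 2 is 123.45, which fits in precision 5
val ok = new Decimal().setOrNull(12345L, 5, 2)        // Decimal representing 123.45

// 123456 unscaled needs 6 digits, so precision 5 overflows and null comes back
val overflow = new Decimal().setOrNull(123456L, 5, 2) // null
}}}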
*/ private[sql] def roundToLong(): Long = { if (decimalVal.eq(null)) { @@ -321,7 +336,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { } catch { case _: ArithmeticException => throw DataTypeErrors.castingCauseOverflowError( - toSqlValue, DecimalType(this.precision, this.scale), LongType) + toSqlValue, + DecimalType(this.precision, this.scale), + LongType) } } } @@ -329,7 +346,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { /** * Update precision and scale while keeping our value the same, and return true if successful. * - * @return true if successful, false if overflow would occur + * @return + * true if successful, false if overflow would occur */ def changePrecision(precision: Int, scale: Int): Boolean = { changePrecision(precision, scale, ROUND_HALF_UP) @@ -338,8 +356,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { /** * Create new `Decimal` with given precision and scale. * - * @return a non-null `Decimal` value if successful. Otherwise, if `nullOnOverflow` is true, null - * is returned; if `nullOnOverflow` is false, an `ArithmeticException` is thrown. + * @return + * a non-null `Decimal` value if successful. Otherwise, if `nullOnOverflow` is true, null is + * returned; if `nullOnOverflow` is false, an `ArithmeticException` is thrown. */ private[sql] def toPrecision( precision: Int, @@ -354,8 +373,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (nullOnOverflow) { null } else { - throw DataTypeErrors.cannotChangeDecimalPrecisionError( - this, precision, scale, context) + throw DataTypeErrors.cannotChangeDecimalPrecisionError(this, precision, scale, context) } } } @@ -363,7 +381,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { /** * Update precision and scale while keeping our value the same, and return true if successful. 
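`changePrecision` is the mutating counterpart of `toPrecision`: it rounds in place (ROUND_HALF_UP by default) and reports success as a Boolean. A rough sketch, assuming the `Decimal(String)` factory referenced elsewhere in this file:

{{{
import org.apache.spark.sql.types.Decimal

val d = Decimal("123.456")     // precision 6, scale 3
d.changePrecision(5, 2)        // true: rounds to 123.46, which fits in (5, 2)
d.changePrecision(4, 2)        // false: 123.46 would need 5 digits of precision
}}}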
* - * @return true if successful, false if overflow would occur + * @return + * true if successful, false if overflow would occur */ private[sql] def changePrecision( precision: Int, @@ -482,7 +501,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { // ------------------------------------------------------------------------ // e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) // e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) - def + (that: Decimal): Decimal = { + def +(that: Decimal): Decimal = { if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) { Decimal(longVal + that.longVal, Math.max(precision, that.precision) + 1, scale) } else { @@ -490,7 +509,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } - def - (that: Decimal): Decimal = { + def -(that: Decimal): Decimal = { if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) { Decimal(longVal - that.longVal, Math.max(precision, that.precision) + 1, scale) } else { @@ -499,14 +518,19 @@ final class Decimal extends Ordered[Decimal] with Serializable { } // TypeCoercion will take care of the precision, scale of result - def * (that: Decimal): Decimal = + def *(that: Decimal): Decimal = Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT)) - def / (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, - DecimalType.MAX_SCALE + 1, MATH_CONTEXT.getRoundingMode)) + def /(that: Decimal): Decimal = + if (that.isZero) { + null + } else { + Decimal( + toJavaBigDecimal + .divide(that.toJavaBigDecimal, DecimalType.MAX_SCALE + 1, MATH_CONTEXT.getRoundingMode)) + } - def % (that: Decimal): Decimal = + def %(that: Decimal): Decimal = if (that.isZero) null else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT)) @@ -526,12 +550,14 @@ final class Decimal extends Ordered[Decimal] with Serializable { def abs: Decimal = if (this < Decimal.ZERO) this.unary_- else this - def floor: Decimal = if (scale == 0) this else { + def floor: Decimal = if (scale == 0) this + else { val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision toPrecision(newPrecision, 0, ROUND_FLOOR, nullOnOverflow = false) } - def ceil: Decimal = if (scale == 0) this else { + def ceil: Decimal = if (scale == 0) this + else { val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision toPrecision(newPrecision, 0, ROUND_CEILING, nullOnOverflow = false) } @@ -612,7 +638,7 @@ object Decimal { // We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow. // For example: Decimal("6.0790316E+25569151") if (numDigitsInIntegralPart(bigDecimal) > DecimalType.MAX_PRECISION && - !SqlApiConf.get.allowNegativeScaleOfDecimalEnabled) { + !SqlApiConf.get.allowNegativeScaleOfDecimalEnabled) { null } else { Decimal(bigDecimal) @@ -632,7 +658,7 @@ object Decimal { // We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow. 
// For example: Decimal("6.0790316E+25569151") if (numDigitsInIntegralPart(bigDecimal) > DecimalType.MAX_PRECISION && - !SqlApiConf.get.allowNegativeScaleOfDecimalEnabled) { + !SqlApiConf.get.allowNegativeScaleOfDecimalEnabled) { throw DataTypeErrors.outOfDecimalTypeRangeError(str) } else { Decimal(bigDecimal) @@ -657,16 +683,18 @@ object Decimal { // Max precision of a decimal value stored in `numBytes` bytes def maxPrecisionForBytes(numBytes: Int): Int = { - Math.round( // convert double to long - Math.floor(Math.log10( // number of base-10 digits - Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes + Math + .round( // convert double to long + Math.floor(Math.log10( // number of base-10 digits + Math.pow(2, 8 * numBytes - 1) - 1)) + ) // max value stored in numBytes .asInstanceOf[Int] } // Returns the minimum number of bytes needed to store a decimal with a given `precision`. lazy val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) - private def computeMinBytesForPrecision(precision : Int) : Int = { + private def computeMinBytesForPrecision(precision: Int): Int = { var numBytes = 1 while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { numBytes += 1 diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 9de34d0b3bc16..bff483cefda91 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -26,9 +26,8 @@ import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.internal.SqlApiConf /** - * The data type representing `java.math.BigDecimal` values. - * A Decimal that must have fixed precision (the maximum number of digits) and scale (the number - * of digits on right side of dot). + * The data type representing `java.math.BigDecimal` values. A Decimal that must have fixed + * precision (the maximum number of digits) and scale (the number of digits on right side of dot). * * The precision can be up to 38, scale can also be up to 38 (less or equal to precision). * @@ -49,7 +48,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { if (precision > DecimalType.MAX_PRECISION) { throw DataTypeErrors.decimalPrecisionExceedsMaxPrecisionError( - precision, DecimalType.MAX_PRECISION) + precision, + DecimalType.MAX_PRECISION) } // default constructor for Java @@ -63,8 +63,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { override def sql: String = typeName.toUpperCase(Locale.ROOT) /** - * Returns whether this DecimalType is wider than `other`. If yes, it means `other` - * can be casted into `this` safely without losing any precision or range. + * Returns whether this DecimalType is wider than `other`. If yes, it means `other` can be + * casted into `this` safely without losing any precision or range. */ private[sql] def isWiderThan(other: DataType): Boolean = isWiderThanInternal(other) @@ -78,8 +78,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { } /** - * Returns whether this DecimalType is tighter than `other`. If yes, it means `this` - * can be casted into `other` safely without losing any precision or range. + * Returns whether this DecimalType is tighter than `other`. If yes, it means `this` can be + * casted into `other` safely without losing any precision or range. 
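`isWiderThan`/`isTighterThan` are private[sql], but the intuition for two decimal types is straightforward: one type can hold every value of the other when it has at least as many integral digits and at least as many fractional digits. A sketch of that containment (the comparison itself stays internal):

{{{
import org.apache.spark.sql.types.DecimalType

val wide   = DecimalType(12, 4)  // 8 integral digits, 4 fractional digits
val narrow = DecimalType(10, 2)  // 8 integral digits, 2 fractional digits

// Every (10, 2) value fits in (12, 4), so casting narrow -> wide loses neither
// precision nor range; the reverse direction could truncate fractional digits.
}}}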
*/ private[sql] def isTighterThan(other: DataType): Boolean = other match { case dt: DecimalType => @@ -94,8 +94,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { } /** - * The default size of a value of the DecimalType is 8 bytes when precision is at most 18, - * and 16 bytes otherwise. + * The default size of a value of the DecimalType is 8 bytes when precision is at most 18, and + * 16 bytes otherwise. */ override def defaultSize: Int = if (precision <= Decimal.MAX_LONG_DIGITS) 8 else 16 @@ -104,7 +104,6 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { private[spark] override def asNullable: DecimalType = this } - /** * Extra factory methods and pattern matchers for Decimals. * @@ -167,10 +166,11 @@ object DecimalType extends AbstractDataType { /** * Scale adjustment implementation is based on Hive's one, which is itself inspired to * SQLServer's one. In particular, when a result precision is greater than - * {@link #MAX_PRECISION}, the corresponding scale is reduced to prevent the integral part of a + * {@link #MAX_PRECISION} , the corresponding scale is reduced to prevent the integral part of a * result from being truncated. * - * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. + * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to + * true. */ private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = { // Assumptions: diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index bc0ed725cf266..873f0c237c6c4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -27,7 +27,8 @@ import org.apache.spark.annotation.Stable * @since 1.3.0 */ @Stable -class DoubleType private() extends FractionalType { +class DoubleType private () extends FractionalType { + /** * The default size of a value of the DoubleType is 8 bytes. */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/FloatType.scala index 8b54f830d48a3..df4b03cd42bd4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -27,7 +27,8 @@ import org.apache.spark.annotation.Stable * @since 1.3.0 */ @Stable -class FloatType private() extends FractionalType { +class FloatType private () extends FractionalType { + /** * The default size of a value of the FloatType is 4 bytes. */ @@ -36,7 +37,6 @@ class FloatType private() extends FractionalType { private[spark] override def asNullable: FloatType = this } - /** * @since 1.3.0 */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/IntegerType.scala index b26a555c9b572..dc4727cb1215b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/IntegerType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -25,7 +25,8 @@ import org.apache.spark.annotation.Stable * @since 1.3.0 */ @Stable -class IntegerType private() extends IntegralType { +class IntegerType private () extends IntegralType { + /** * The default size of a value of the IntegerType is 4 bytes. 
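The `defaultSize` figures reformatted in this hunk are easy to sanity-check against the public API:

{{{
import org.apache.spark.sql.types.{DecimalType, DoubleType, IntegerType}

DecimalType(10, 2).defaultSize  // 8: precision <= 18 is backed by a Long
DecimalType(38, 6).defaultSize  // 16: larger precisions fall back to java.math.BigDecimal
DoubleType.defaultSize          // 8
IntegerType.defaultSize         // 4
}}}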
*/ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/LongType.scala index 87ebacfe9ce88..f65c4c70acd27 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/LongType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -25,7 +25,8 @@ import org.apache.spark.annotation.Stable * @since 1.3.0 */ @Stable -class LongType private() extends IntegralType { +class LongType private () extends IntegralType { + /** * The default size of a value of the LongType is 8 bytes. */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala index dba870466fc1c..1dfb9aaf9e29b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -28,15 +28,16 @@ import org.apache.spark.sql.catalyst.util.StringConcat * * Please use `DataTypes.createMapType()` to create a specific instance. * - * @param keyType The data type of map keys. - * @param valueType The data type of map values. - * @param valueContainsNull Indicates if map values have `null` values. + * @param keyType + * The data type of map keys. + * @param valueType + * The data type of map values. + * @param valueContainsNull + * Indicates if map values have `null` values. */ @Stable -case class MapType( - keyType: DataType, - valueType: DataType, - valueContainsNull: Boolean) extends DataType { +case class MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) + extends DataType { /** No-arg constructor for kryo. */ def this() = this(null, null, false) @@ -48,8 +49,9 @@ case class MapType( if (maxDepth > 0) { stringConcat.append(s"$prefix-- key: ${keyType.typeName}\n") DataType.buildFormattedString(keyType, s"$prefix |", stringConcat, maxDepth) - stringConcat.append(s"$prefix-- value: ${valueType.typeName} " + - s"(valueContainsNull = $valueContainsNull)\n") + stringConcat.append( + s"$prefix-- value: ${valueType.typeName} " + + s"(valueContainsNull = $valueContainsNull)\n") DataType.buildFormattedString(valueType, s"$prefix |", stringConcat, maxDepth) } } @@ -61,9 +63,9 @@ case class MapType( ("valueContainsNull" -> valueContainsNull) /** - * The default size of a value of the MapType is - * (the default size of the key type + the default size of the value type). - * We assume that there is only 1 element on average in a map. See SPARK-18853. + * The default size of a value of the MapType is (the default size of the key type + the default + * size of the value type). We assume that there is only 1 element on average in a map. See + * SPARK-18853. */ override def defaultSize: Int = 1 * (keyType.defaultSize + valueType.defaultSize) @@ -77,8 +79,8 @@ case class MapType( MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) /** - * Returns the same data type but set all nullability fields are true - * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + * Returns the same data type but set all nullability fields are true (`StructField.nullable`, + * `ArrayType.containsNull`, and `MapType.valueContainsNull`). * * @since 4.0.0 */ @@ -104,8 +106,8 @@ object MapType extends AbstractDataType { override private[sql] def simpleString: String = "map" /** - * Construct a [[MapType]] object with the given key type and value type. - * The `valueContainsNull` is true. 
+ * Construct a [[MapType]] object with the given key type and value type. The + * `valueContainsNull` is true. */ def apply(keyType: DataType, valueType: DataType): MapType = MapType(keyType: DataType, valueType: DataType, valueContainsNull = true) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 70e03905d4b05..05f91b5ba2313 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -26,7 +26,6 @@ import org.apache.spark.annotation.Stable import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.util.ArrayImplicits._ - /** * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and @@ -35,13 +34,14 @@ import org.apache.spark.util.ArrayImplicits._ * The default constructor is private. User should use either [[MetadataBuilder]] or * `Metadata.fromJson()` to create Metadata instances. * - * @param map an immutable map that stores the data + * @param map + * an immutable map that stores the data * * @since 1.3.0 */ @Stable sealed class Metadata private[types] (private[types] val map: Map[String, Any]) - extends Serializable { + extends Serializable { /** No-arg constructor for kryo. */ protected def this() = this(null) @@ -173,7 +173,8 @@ object Metadata { builder.putStringArray(key, value.asInstanceOf[List[JString]].map(_.s).toArray) case _: JObject => builder.putMetadataArray( - key, value.asInstanceOf[List[JObject]].map(fromJObject).toArray) + key, + value.asInstanceOf[List[JObject]].map(fromJObject).toArray) case other => throw DataTypeErrors.unsupportedArrayTypeError(other.getClass) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/NullType.scala index d211fac70c641..4e7fd3a00a8af 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/NullType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/NullType.scala @@ -25,7 +25,7 @@ import org.apache.spark.annotation.Stable * @since 1.3.0 */ @Stable -class NullType private() extends DataType { +class NullType private () extends DataType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "NullType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ShortType.scala index 66696793e6279..c3b6bc75facd3 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/ShortType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -25,7 +25,8 @@ import org.apache.spark.annotation.Stable * @since 1.3.0 */ @Stable -class ShortType private() extends IntegralType { +class ShortType private () extends IntegralType { + /** * The default size of a value of the ShortType is 2 bytes. 
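Tying together the `MapType` pieces reformatted earlier in this hunk: the two-argument factory defaults `valueContainsNull` to true, and `defaultSize` assumes a single average entry, so it is simply the key size plus the value size.

{{{
import org.apache.spark.sql.types.{IntegerType, MapType, StringType}

val m = MapType(StringType, IntegerType)
m.valueContainsNull   // true
m.defaultSize         // StringType.defaultSize + IntegerType.defaultSize
}}}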
*/ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index df77f091f41f4..c2dd6cec7ba74 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -26,15 +26,17 @@ import org.apache.spark.sql.catalyst.util.CollationFactory * The data type representing `String` values. Please use the singleton `DataTypes.StringType`. * * @since 1.3.0 - * @param collationId The id of collation for this StringType. + * @param collationId + * The id of collation for this StringType. */ @Stable -class StringType private(val collationId: Int) extends AtomicType with Serializable { +class StringType private (val collationId: Int) extends AtomicType with Serializable { + /** - * Support for Binary Equality implies that strings are considered equal only if - * they are byte for byte equal. E.g. all accent or case-insensitive collations are considered - * non-binary. If this field is true, byte level operations can be used against this datatype - * (e.g. for equality and hashing). + * Support for Binary Equality implies that strings are considered equal only if they are byte + * for byte equal. E.g. all accent or case-insensitive collations are considered non-binary. If + * this field is true, byte level operations can be used against this datatype (e.g. for + * equality and hashing). */ private[sql] def supportsBinaryEquality: Boolean = CollationFactory.fetchCollation(collationId).supportsBinaryEquality @@ -42,6 +44,9 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa private[sql] def supportsLowercaseEquality: Boolean = CollationFactory.fetchCollation(collationId).supportsLowercaseEquality + private[sql] def isNonCSAI: Boolean = + !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId) + private[sql] def isUTF8BinaryCollation: Boolean = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID @@ -49,18 +54,18 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa collationId == CollationFactory.UTF8_LCASE_COLLATION_ID /** - * Support for Binary Ordering implies that strings are considered equal only - * if they are byte for byte equal. E.g. all accent or case-insensitive collations are - * considered non-binary. Also their ordering does not require calls to ICU library, as - * it follows spark internal implementation. If this field is true, byte level operations - * can be used against this datatype (e.g. for equality, hashing and ordering). + * Support for Binary Ordering implies that strings are considered equal only if they are byte + * for byte equal. E.g. all accent or case-insensitive collations are considered non-binary. + * Also their ordering does not require calls to ICU library, as it follows spark internal + * implementation. If this field is true, byte level operations can be used against this + * datatype (e.g. for equality, hashing and ordering). */ private[sql] def supportsBinaryOrdering: Boolean = CollationFactory.fetchCollation(collationId).supportsBinaryOrdering /** - * Type name that is shown to the customer. - * If this is an UTF8_BINARY collation output is `string` due to backwards compatibility. + * Type name that is shown to the customer. If this is an UTF8_BINARY collation output is + * `string` due to backwards compatibility. 
*/ override def typeName: String = if (isUTF8BinaryCollation) "string" diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala index 3ff96fea9ee04..d4e590629921c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -31,11 +31,15 @@ import org.apache.spark.util.SparkSchemaUtils /** * A field inside a StructType. - * @param name The name of this field. - * @param dataType The data type of this field. - * @param nullable Indicates if values of this field can be `null` values. - * @param metadata The metadata of this field. The metadata should be preserved during - * transformation if the content of the column is not modified, e.g, in selection. + * @param name + * The name of this field. + * @param dataType + * The data type of this field. + * @param nullable + * Indicates if values of this field can be `null` values. + * @param metadata + * The metadata of this field. The metadata should be preserved during transformation if the + * content of the column is not modified, e.g, in selection. * * @since 1.3.0 */ @@ -54,8 +58,9 @@ case class StructField( stringConcat: StringConcat, maxDepth: Int): Unit = { if (maxDepth > 0) { - stringConcat.append(s"$prefix-- ${SparkSchemaUtils.escapeMetaCharacters(name)}: " + - s"${dataType.typeName} (nullable = $nullable)\n") + stringConcat.append( + s"$prefix-- ${SparkSchemaUtils.escapeMetaCharacters(name)}: " + + s"${dataType.typeName} (nullable = $nullable)\n") DataType.buildFormattedString(dataType, s"$prefix |", stringConcat, maxDepth) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala index 3e637b5110122..4ef1cf400b80e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -41,10 +41,10 @@ import org.apache.spark.util.SparkCollectionUtils * {{{ * StructType(fields: Seq[StructField]) * }}} - * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. - * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. - * If a provided name does not have a matching field, it will be ignored. For the case - * of extracting a single [[StructField]], a `null` will be returned. + * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. If + * multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. If a + * provided name does not have a matching field, it will be ignored. For the case of extracting a + * single [[StructField]], a `null` will be returned. 
* * Scala Example: * {{{ @@ -126,8 +126,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override def equals(that: Any): Boolean = { that match { case StructType(otherFields) => - java.util.Arrays.equals( - fields.asInstanceOf[Array[AnyRef]], otherFields.asInstanceOf[Array[AnyRef]]) + java.util.Arrays + .equals(fields.asInstanceOf[Array[AnyRef]], otherFields.asInstanceOf[Array[AnyRef]]) case _ => false } } @@ -146,7 +146,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * .add(StructField("a", IntegerType, true)) * .add(StructField("b", LongType, false)) * .add(StructField("c", StringType, true)) - *}}} + * }}} */ def add(field: StructField): StructType = { StructType(fields :+ field) @@ -155,10 +155,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru /** * Creates a new [[StructType]] by adding a new nullable field with no metadata. * - * val struct = (new StructType) - * .add("a", IntegerType) - * .add("b", LongType) - * .add("c", StringType) + * val struct = (new StructType) .add("a", IntegerType) .add("b", LongType) .add("c", + * StringType) */ def add(name: String, dataType: DataType): StructType = { StructType(fields :+ StructField(name, dataType, nullable = true, Metadata.empty)) @@ -167,10 +165,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru /** * Creates a new [[StructType]] by adding a new field with no metadata. * - * val struct = (new StructType) - * .add("a", IntegerType, true) - * .add("b", LongType, false) - * .add("c", StringType, true) + * val struct = (new StructType) .add("a", IntegerType, true) .add("b", LongType, false) + * .add("c", StringType, true) */ def add(name: String, dataType: DataType, nullable: Boolean): StructType = { StructType(fields :+ StructField(name, dataType, nullable, Metadata.empty)) @@ -185,11 +181,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * .add("c", StringType, true, Metadata.empty) * }}} */ - def add( - name: String, - dataType: DataType, - nullable: Boolean, - metadata: Metadata): StructType = { + def add(name: String, dataType: DataType, nullable: Boolean, metadata: Metadata): StructType = { StructType(fields :+ StructField(name, dataType, nullable, metadata)) } @@ -202,11 +194,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * .add("c", StringType, true, "comment3") * }}} */ - def add( - name: String, - dataType: DataType, - nullable: Boolean, - comment: String): StructType = { + def add(name: String, dataType: DataType, nullable: Boolean, comment: String): StructType = { StructType(fields :+ StructField(name, dataType, nullable).withComment(comment)) } @@ -226,8 +214,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru } /** - * Creates a new [[StructType]] by adding a new field with no metadata where the - * dataType is specified as a String. + * Creates a new [[StructType]] by adding a new field with no metadata where the dataType is + * specified as a String. * * {{{ * val struct = (new StructType) @@ -241,8 +229,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru } /** - * Creates a new [[StructType]] by adding a new field and specifying metadata where the - * dataType is specified as a String. + * Creates a new [[StructType]] by adding a new field and specifying metadata where the dataType + * is specified as a String. 
* {{{ * val struct = (new StructType) * .add("a", "int", true, Metadata.empty) @@ -250,17 +238,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * .add("c", "string", true, Metadata.empty) * }}} */ - def add( - name: String, - dataType: String, - nullable: Boolean, - metadata: Metadata): StructType = { + def add(name: String, dataType: String, nullable: Boolean, metadata: Metadata): StructType = { add(name, DataTypeParser.parseDataType(dataType), nullable, metadata) } /** - * Creates a new [[StructType]] by adding a new field and specifying metadata where the - * dataType is specified as a String. + * Creates a new [[StructType]] by adding a new field and specifying metadata where the dataType + * is specified as a String. * {{{ * val struct = (new StructType) * .add("a", "int", true, "comment1") @@ -268,21 +252,19 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * .add("c", "string", true, "comment3") * }}} */ - def add( - name: String, - dataType: String, - nullable: Boolean, - comment: String): StructType = { + def add(name: String, dataType: String, nullable: Boolean, comment: String): StructType = { add(name, DataTypeParser.parseDataType(dataType), nullable, comment) } /** * Extracts the [[StructField]] with the given name. * - * @throws IllegalArgumentException if a field with the given name does not exist + * @throws IllegalArgumentException + * if a field with the given name does not exist */ def apply(name: String): StructField = { - nameToField.getOrElse(name, + nameToField.getOrElse( + name, throw new SparkIllegalArgumentException( errorClass = "FIELD_NOT_FOUND", messageParameters = immutable.Map( @@ -294,7 +276,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * Returns a [[StructType]] containing [[StructField]]s of the given names, preserving the * original order of fields. * - * @throws IllegalArgumentException if at least one given field name does not exist + * @throws IllegalArgumentException + * if at least one given field name does not exist */ def apply(names: Set[String]): StructType = { val nonExistFields = names -- fieldNamesSet @@ -312,10 +295,12 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru /** * Returns the index of a given field. 
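The builder and lookup methods touched in this hunk compose naturally; a short sketch using only calls shown above:

{{{
import org.apache.spark.sql.types._

val schema = new StructType()
  .add("id", LongType, nullable = false)
  .add("name", StringType)
  .add("score", DoubleType)

schema.fieldIndex("name")    // 1
schema(Set("id", "score"))   // StructType with just those two fields, original order kept
}}}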
* - * @throws IllegalArgumentException if a field with the given name does not exist + * @throws IllegalArgumentException + * if a field with the given name does not exist */ def fieldIndex(name: String): Int = { - nameToIndex.getOrElse(name, + nameToIndex.getOrElse( + name, throw new SparkIllegalArgumentException( errorClass = "FIELD_NOT_FOUND", messageParameters = immutable.Map( @@ -354,10 +339,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru } else if (found.isEmpty) { None } else { - findField( - parent = found.head, - searchPath = searchPath.tail, - normalizedPath) + findField(parent = found.head, searchPath = searchPath.tail, normalizedPath) } } @@ -433,11 +415,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum override def simpleString: String = { - val fieldTypes = fields.to(LazyList) + val fieldTypes = fields + .to(LazyList) .map(field => s"${field.name}:${field.dataType.simpleString}") SparkStringUtils.truncatedString( fieldTypes, - "struct<", ",", ">", + "struct<", + ",", + ">", SqlApiConf.get.maxToStringFields) } @@ -460,9 +445,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru /** * Returns a string containing a schema in DDL format. For example, the following value: - * `StructType(Seq(StructField("eventId", IntegerType), StructField("s", StringType)))` - * will be converted to `eventId` INT, `s` STRING. - * The returned DDL schema can be used in a table creation. + * `StructType(Seq(StructField("eventId", IntegerType), StructField("s", StringType)))` will be + * converted to `eventId` INT, `s` STRING. The returned DDL schema can be used in a table + * creation. * * @since 2.4.0 */ @@ -470,8 +455,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private[sql] override def simpleString(maxNumberFields: Int): String = { val builder = new StringBuilder - val fieldTypes = fields.take(maxNumberFields).map { - f => s"${f.name}: ${f.dataType.simpleString(maxNumberFields)}" + val fieldTypes = fields.take(maxNumberFields).map { f => + s"${f.name}: ${f.dataType.simpleString(maxNumberFields)}" } builder.append("struct<") builder.append(fieldTypes.mkString(", ")) @@ -486,31 +471,29 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru } /** - * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field - * B from `that`, + * Merges with another schema (`StructType`). For a struct field A from `this` and a struct + * field B from `that`, * - * 1. If A and B have the same name and data type, they are merged to a field C with the same name - * and data type. C is nullable if and only if either A or B is nullable. - * 2. If A doesn't exist in `that`, it's included in the result schema. - * 3. If B doesn't exist in `this`, it's also included in the result schema. - * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be - * thrown. + * 1. If A and B have the same name and data type, they are merged to a field C with the same + * name and data type. C is nullable if and only if either A or B is nullable. 2. If A + * doesn't exist in `that`, it's included in the result schema. 3. If B doesn't exist in + * `this`, it's also included in the result schema. 4. Otherwise, `this` and `that` are + * considered as conflicting schemas and an exception would be thrown. 
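Since `merge` is private[sql], the four rules above are best read against a concrete pair of schemas; the snippet below only builds the inputs, with the expected outcome spelled out in comments:

{{{
import org.apache.spark.sql.types._

val left  = new StructType().add("a", IntegerType, nullable = false).add("b", StringType)
val right = new StructType().add("b", StringType, nullable = false).add("c", DoubleType)

// Merging left with right should yield, per rules 1-3:
//   a INT    (only in left)
//   b STRING (in both; nullable, because left's b is nullable)
//   c DOUBLE (only in right)
// Rule 4 would only apply if b had conflicting data types.
}}}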
*/ private[sql] def merge(that: StructType, caseSensitive: Boolean = true): StructType = StructType.merge(this, that, caseSensitive).asInstanceOf[StructType] override private[spark] def asNullable: StructType = { - val newFields = fields.map { - case StructField(name, dataType, nullable, metadata) => - StructField(name, dataType.asNullable, nullable = true, metadata) + val newFields = fields.map { case StructField(name, dataType, nullable, metadata) => + StructField(name, dataType.asNullable, nullable = true, metadata) } StructType(newFields) } /** - * Returns the same data type but set all nullability fields are true - * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + * Returns the same data type but set all nullability fields are true (`StructField.nullable`, + * `ArrayType.containsNull`, and `MapType.valueContainsNull`). * * @since 4.0.0 */ @@ -562,7 +545,8 @@ object StructType extends AbstractDataType { case StructType(fields) => val newFields = fields.map { f => val mb = new MetadataBuilder() - f.copy(dataType = removeMetadata(key, f.dataType), + f.copy( + dataType = removeMetadata(key, f.dataType), metadata = mb.withMetadata(f.metadata).remove(key).build()) } StructType(newFields) @@ -570,39 +554,49 @@ object StructType extends AbstractDataType { } /** - * This leverages `merge` to merge data types for UNION operator by specializing - * the handling of struct types to follow UNION semantics. + * This leverages `merge` to merge data types for UNION operator by specializing the handling of + * struct types to follow UNION semantics. */ private[sql] def unionLikeMerge(left: DataType, right: DataType): DataType = - mergeInternal(left, right, (s1: StructType, s2: StructType) => { - val leftFields = s1.fields - val rightFields = s2.fields - require(leftFields.length == rightFields.length, "To merge nullability, " + - "two structs must have same number of fields.") - - val newFields = leftFields.zip(rightFields).map { - case (leftField, rightField) => + mergeInternal( + left, + right, + (s1: StructType, s2: StructType) => { + val leftFields = s1.fields + val rightFields = s2.fields + require( + leftFields.length == rightFields.length, + "To merge nullability, " + + "two structs must have same number of fields.") + + val newFields = leftFields.zip(rightFields).map { case (leftField, rightField) => leftField.copy( dataType = unionLikeMerge(leftField.dataType, rightField.dataType), nullable = leftField.nullable || rightField.nullable) - } - StructType(newFields) - }) - - private[sql] def merge(left: DataType, right: DataType, caseSensitive: Boolean = true): DataType = - mergeInternal(left, right, (s1: StructType, s2: StructType) => { - val leftFields = s1.fields - val rightFields = s2.fields - val newFields = mutable.ArrayBuffer.empty[StructField] + } + StructType(newFields) + }) - def normalize(name: String): String = { - if (caseSensitive) name else name.toLowerCase(Locale.ROOT) - } + private[sql] def merge( + left: DataType, + right: DataType, + caseSensitive: Boolean = true): DataType = + mergeInternal( + left, + right, + (s1: StructType, s2: StructType) => { + val leftFields = s1.fields + val rightFields = s2.fields + val newFields = mutable.ArrayBuffer.empty[StructField] + + def normalize(name: String): String = { + if (caseSensitive) name else name.toLowerCase(Locale.ROOT) + } - val rightMapped = fieldsMap(rightFields, caseSensitive) - leftFields.foreach { - case leftField @ StructField(leftName, leftType, leftNullable, _) => - 
rightMapped.get(normalize(leftName)) + val rightMapped = fieldsMap(rightFields, caseSensitive) + leftFields.foreach { case leftField @ StructField(leftName, leftType, leftNullable, _) => + rightMapped + .get(normalize(leftName)) .map { case rightField @ StructField(rightName, rightType, rightNullable, _) => try { leftField.copy( @@ -610,39 +604,40 @@ object StructType extends AbstractDataType { nullable = leftNullable || rightNullable) } catch { case NonFatal(e) => - throw DataTypeErrors.cannotMergeIncompatibleDataTypesError( - leftType, rightType) + throw DataTypeErrors.cannotMergeIncompatibleDataTypesError(leftType, rightType) } } .orElse { Some(leftField) } .foreach(newFields += _) - } - - val leftMapped = fieldsMap(leftFields, caseSensitive) - rightFields - .filterNot(f => leftMapped.contains(normalize(f.name))) - .foreach { f => - newFields += f } - StructType(newFields.toArray) - }) + val leftMapped = fieldsMap(leftFields, caseSensitive) + rightFields + .filterNot(f => leftMapped.contains(normalize(f.name))) + .foreach { f => + newFields += f + } + + StructType(newFields.toArray) + }) private def mergeInternal( left: DataType, right: DataType, mergeStruct: (StructType, StructType) => StructType): DataType = (left, right) match { - case (ArrayType(leftElementType, leftContainsNull), - ArrayType(rightElementType, rightContainsNull)) => + case ( + ArrayType(leftElementType, leftContainsNull), + ArrayType(rightElementType, rightContainsNull)) => ArrayType( mergeInternal(leftElementType, rightElementType, mergeStruct), leftContainsNull || rightContainsNull) - case (MapType(leftKeyType, leftValueType, leftContainsNull), - MapType(rightKeyType, rightValueType, rightContainsNull)) => + case ( + MapType(leftKeyType, leftValueType, leftContainsNull), + MapType(rightKeyType, rightValueType, rightContainsNull)) => MapType( mergeInternal(leftKeyType, rightKeyType, mergeStruct), mergeInternal(leftValueType, rightValueType, mergeStruct), @@ -650,17 +645,20 @@ object StructType extends AbstractDataType { case (s1: StructType, s2: StructType) => mergeStruct(s1, s2) - case (DecimalType.Fixed(leftPrecision, leftScale), - DecimalType.Fixed(rightPrecision, rightScale)) => + case ( + DecimalType.Fixed(leftPrecision, leftScale), + DecimalType.Fixed(rightPrecision, rightScale)) => if (leftScale == rightScale) { DecimalType(leftPrecision.max(rightPrecision), leftScale) } else { throw DataTypeErrors.cannotMergeDecimalTypesWithIncompatibleScaleError( - leftScale, rightScale) + leftScale, + rightScale) } case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) - if leftUdt.userClass == rightUdt.userClass => leftUdt + if leftUdt.userClass == rightUdt.userClass => + leftUdt case (YearMonthIntervalType(lstart, lend), YearMonthIntervalType(rstart, rend)) => YearMonthIntervalType(Math.min(lstart, rstart).toByte, Math.max(lend, rend).toByte) @@ -706,10 +704,12 @@ object StructType extends AbstractDataType { // Found a missing field in `source`. newFields += field } else if (bothStructType(found.get.dataType, field.dataType) && - !found.get.dataType.sameType(field.dataType)) { + !found.get.dataType.sameType(field.dataType)) { // Found a field with same name, but different data type. 
- findMissingFields(found.get.dataType.asInstanceOf[StructType], - field.dataType.asInstanceOf[StructType], resolver).map { missingType => + findMissingFields( + found.get.dataType.asInstanceOf[StructType], + field.dataType.asInstanceOf[StructType], + resolver).map { missingType => newFields += found.get.copy(dataType = missingType) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala index 9968d75dd2577..b08d16f0e2c97 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala @@ -20,16 +20,17 @@ package org.apache.spark.sql.types import org.apache.spark.annotation.Unstable /** - * The timestamp without time zone type represents a local time in microsecond precision, - * which is independent of time zone. - * Its valid range is [0001-01-01T00:00:00.000000, 9999-12-31T23:59:59.999999]. - * To represent an absolute point in time, use `TimestampType` instead. + * The timestamp without time zone type represents a local time in microsecond precision, which is + * independent of time zone. Its valid range is [0001-01-01T00:00:00.000000, + * 9999-12-31T23:59:59.999999]. To represent an absolute point in time, use `TimestampType` + * instead. * * Please use the singleton `DataTypes.TimestampNTZType` to refer the type. * @since 3.4.0 */ @Unstable -class TimestampNTZType private() extends DatetimeType { +class TimestampNTZType private () extends DatetimeType { + /** * The default size of a value of the TimestampNTZType is 8 bytes. */ @@ -42,9 +43,9 @@ class TimestampNTZType private() extends DatetimeType { /** * The companion case object and its class is separated so the companion object also subclasses - * the TimestampNTZType class. Otherwise, the companion object would be of type - * "TimestampNTZType" in byte code. Defined with a private constructor so the companion - * object is the only possible instantiation. + * the TimestampNTZType class. Otherwise, the companion object would be of type "TimestampNTZType" + * in byte code. Defined with a private constructor so the companion object is the only possible + * instantiation. * * @since 3.4.0 */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampType.scala index 1185e4a9e32ca..bf869d1f38c57 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql.types import org.apache.spark.annotation.Stable /** - * The timestamp type represents a time instant in microsecond precision. - * Valid range is [0001-01-01T00:00:00.000000Z, 9999-12-31T23:59:59.999999Z] where - * the left/right-bound is a date and time of the proleptic Gregorian - * calendar in UTC+00:00. + * The timestamp type represents a time instant in microsecond precision. Valid range is + * [0001-01-01T00:00:00.000000Z, 9999-12-31T23:59:59.999999Z] where the left/right-bound is a date + * and time of the proleptic Gregorian calendar in UTC+00:00. * * Please use the singleton `DataTypes.TimestampType` to refer the type. * @since 1.3.0 */ @Stable -class TimestampType private() extends DatetimeType { +class TimestampType private () extends DatetimeType { + /** * The default size of a value of the TimestampType is 8 bytes. 
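The two timestamp flavours documented in this hunk differ only in whether a value is anchored to an instant: `TimestampType` stores a UTC-adjusted microsecond instant, while `TimestampNTZType` stores a zone-less local date-time. Both singletons drop straight into a schema:

{{{
import org.apache.spark.sql.types._

val events = new StructType()
  .add("event_time", TimestampType)     // absolute point in time, microsecond precision
  .add("local_time", TimestampNTZType)  // wall-clock time, independent of time zone
}}}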
*/ @@ -40,8 +40,8 @@ class TimestampType private() extends DatetimeType { /** * The companion case object and its class is separated so the companion object also subclasses - * the TimestampType class. Otherwise, the companion object would be of type "TimestampType$" - * in byte code. Defined with a private constructor so the companion object is the only possible + * the TimestampType class. Otherwise, the companion object would be of type "TimestampType$" in + * byte code. Defined with a private constructor so the companion object is the only possible * instantiation. * * @since 1.3.0 diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala index 9219c1d139b99..85d421a07577b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala @@ -45,21 +45,26 @@ object UDTRegistration extends Serializable with Logging { /** * Queries if a given user class is already registered or not. - * @param userClassName the name of user class - * @return boolean value indicates if the given user class is registered or not + * @param userClassName + * the name of user class + * @return + * boolean value indicates if the given user class is registered or not */ def exists(userClassName: String): Boolean = udtMap.contains(userClassName) /** - * Registers an UserDefinedType to an user class. If the user class is already registered - * with another UserDefinedType, warning log message will be shown. - * @param userClass the name of user class - * @param udtClass the name of UserDefinedType class for the given userClass + * Registers an UserDefinedType to an user class. If the user class is already registered with + * another UserDefinedType, warning log message will be shown. + * @param userClass + * the name of user class + * @param udtClass + * the name of UserDefinedType class for the given userClass */ def register(userClass: String, udtClass: String): Unit = { if (udtMap.contains(userClass)) { - logWarning(log"Cannot register UDT for ${MDC(LogKeys.CLASS_NAME, userClass)}, " + - log"which is already registered.") + logWarning( + log"Cannot register UDT for ${MDC(LogKeys.CLASS_NAME, userClass)}, " + + log"which is already registered.") } else { // When register UDT with class name, we can't check if the UDT class is an UserDefinedType, // or not. The check is deferred. @@ -69,8 +74,10 @@ object UDTRegistration extends Serializable with Logging { /** * Returns the Class of UserDefinedType for the name of a given user class. - * @param userClass class name of user class - * @return Option value of the Class object of UserDefinedType + * @param userClass + * class name of user class + * @return + * Option value of the Class object of UserDefinedType */ def getUDTFor(userClass: String): Option[Class[_]] = { udtMap.get(userClass).map { udtClassName => diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala index 7ec00bde0b25f..4993e249b3059 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala @@ -22,13 +22,8 @@ package org.apache.spark.sql.types private[sql] object UpCastRule { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. 
// The conversion for integral and floating point types have a linear widening hierarchy: - val numericPrecedence: IndexedSeq[NumericType] = IndexedSeq( - ByteType, - ShortType, - IntegerType, - LongType, - FloatType, - DoubleType) + val numericPrecedence: IndexedSeq[NumericType] = + IndexedSeq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType) /** * Returns true iff we can safely up-cast the `from` type to `to` type without any truncating or @@ -62,10 +57,9 @@ private[sql] object UpCastRule { case (StructType(fromFields), StructType(toFields)) => fromFields.length == toFields.length && - fromFields.zip(toFields).forall { - case (f1, f2) => - resolvableNullability(f1.nullable, f2.nullable) && canUpCast(f1.dataType, f2.dataType) - } + fromFields.zip(toFields).forall { case (f1, f2) => + resolvableNullability(f1.nullable, f2.nullable) && canUpCast(f1.dataType, f2.dataType) + } case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 5cbd876b31e68..dd8ca26c52462 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -27,15 +27,14 @@ import org.apache.spark.annotation.{DeveloperApi, Since} /** * The data type for User Defined Types (UDTs). * - * This interface allows a user to make their own classes more interoperable with SparkSQL; - * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create - * a `DataFrame` which has class X in the schema. + * This interface allows a user to make their own classes more interoperable with SparkSQL; e.g., + * by creating a [[UserDefinedType]] for a class X, it becomes possible to create a `DataFrame` + * which has class X in the schema. * - * For SparkSQL to recognize UDTs, the UDT must be annotated with - * [[SQLUserDefinedType]]. + * For SparkSQL to recognize UDTs, the UDT must be annotated with [[SQLUserDefinedType]]. * - * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD. - * The conversion via `deserialize` occurs when reading from a `DataFrame`. + * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD. The + * conversion via `deserialize` occurs when reading from a `DataFrame`. */ @DeveloperApi @Since("3.2.0") @@ -81,7 +80,7 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa override private[sql] def acceptsType(dataType: DataType): Boolean = dataType match { case other: UserDefinedType[_] if this.userClass != null && other.userClass != null => this.getClass == other.getClass || - this.userClass.isAssignableFrom(other.userClass) + this.userClass.isAssignableFrom(other.userClass) case _ => false } @@ -98,6 +97,7 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa } private[spark] object UserDefinedType { + /** * Get the sqlType of a (potential) [[UserDefinedType]]. 
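`UpCastRule` itself is private[sql], but the widening order it is built on is listed earlier in this hunk; restating it locally makes the "safe up-cast" idea easy to check:

{{{
import org.apache.spark.sql.types._

// Byte -> Short -> Int -> Long -> Float -> Double, as in numericPrecedence above
val precedence = IndexedSeq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)

// A numeric cast is a safe up-cast when the target is at or after the source in this order.
precedence.indexOf(IntegerType) <= precedence.indexOf(LongType)    // true: Int to Long is safe
precedence.indexOf(LongType)    <= precedence.indexOf(IntegerType) // false: Long to Int may truncate
}}}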
*/ @@ -115,7 +115,8 @@ private[spark] object UserDefinedType { private[sql] class PythonUserDefinedType( val sqlType: DataType, override val pyUDT: String, - override val serializedPyClass: String) extends UserDefinedType[Any] { + override val serializedPyClass: String) + extends UserDefinedType[Any] { /* The serialization is handled by UDT class in Python */ override def serialize(obj: Any): Any = obj diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala index 103fe7a59fc83..4d775c3e1e390 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.types import org.apache.spark.annotation.Unstable /** - * The data type representing semi-structured values with arbitrary hierarchical data structures. It - * is intended to store parsed JSON values and most other data types in the system (e.g., it cannot - * store a map with a non-string key type). + * The data type representing semi-structured values with arbitrary hierarchical data structures. + * It is intended to store parsed JSON values and most other data types in the system (e.g., it + * cannot store a map with a non-string key type). * * @since 4.0.0 */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala index 6532a3b220c5b..f69054f2c1fbc 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala @@ -29,18 +29,19 @@ import org.apache.spark.sql.types.YearMonthIntervalType.fieldToString * * `YearMonthIntervalType` represents positive as well as negative year-month intervals. * - * @param startField The leftmost field which the type comprises of. Valid values: - * 0 (YEAR), 1 (MONTH). - * @param endField The rightmost field which the type comprises of. Valid values: - * 0 (YEAR), 1 (MONTH). + * @param startField + * The leftmost field which the type comprises of. Valid values: 0 (YEAR), 1 (MONTH). + * @param endField + * The rightmost field which the type comprises of. Valid values: 0 (YEAR), 1 (MONTH). * * @since 3.2.0 */ @Unstable case class YearMonthIntervalType(startField: Byte, endField: Byte) extends AnsiIntervalType { + /** - * Year-month interval values always occupy 4 bytes. - * The YEAR field is constrained by the upper bound 178956970 to fit to `Int`. + * Year-month interval values always occupy 4 bytes. The YEAR field is constrained by the upper + * bound 178956970 to fit to `Int`. */ override def defaultSize: Int = 4 diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 6852fe09ef96b..1740cbe2957b8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -38,33 +38,33 @@ private[sql] object ArrowUtils { // todo: support more types. /** Maps data type from Spark to Arrow. 
NOTE: timeZoneId required for TimestampTypes */ - def toArrowType( - dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = dt match { - case BooleanType => ArrowType.Bool.INSTANCE - case ByteType => new ArrowType.Int(8, true) - case ShortType => new ArrowType.Int(8 * 2, true) - case IntegerType => new ArrowType.Int(8 * 4, true) - case LongType => new ArrowType.Int(8 * 8, true) - case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) - case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case _: StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE - case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE - case _: StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE - case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE - case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale, 8 * 16) - case DateType => new ArrowType.Date(DateUnit.DAY) - case TimestampType if timeZoneId == null => - throw SparkException.internalError("Missing timezoneId where it is mandatory.") - case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) - case TimestampNTZType => - new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) - case NullType => ArrowType.Null.INSTANCE - case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) - case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) - case CalendarIntervalType => new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO) - case _ => - throw ExecutionErrors.unsupportedDataTypeError(dt) - } + def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = + dt match { + case BooleanType => ArrowType.Bool.INSTANCE + case ByteType => new ArrowType.Int(8, true) + case ShortType => new ArrowType.Int(8 * 2, true) + case IntegerType => new ArrowType.Int(8 * 4, true) + case LongType => new ArrowType.Int(8 * 8, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case _: StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE + case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE + case _: StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE + case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE + case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale, 8 * 16) + case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType if timeZoneId == null => + throw SparkException.internalError("Missing timezoneId where it is mandatory.") + case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) + case TimestampNTZType => + new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) + case NullType => ArrowType.Null.INSTANCE + case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) + case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) + case CalendarIntervalType => new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO) + case _ => + throw ExecutionErrors.unsupportedDataTypeError(dt) + } def fromArrowType(dt: ArrowType): DataType = dt match { case ArrowType.Bool.INSTANCE => BooleanType @@ -73,9 +73,11 @@ private[sql] object ArrowUtils { case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType case float: 
ArrowType.FloatingPoint - if float.getPrecision() == FloatingPointPrecision.SINGLE => FloatType + if float.getPrecision() == FloatingPointPrecision.SINGLE => + FloatType case float: ArrowType.FloatingPoint - if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType + if float.getPrecision() == FloatingPointPrecision.DOUBLE => + DoubleType case ArrowType.Utf8.INSTANCE => StringType case ArrowType.Binary.INSTANCE => BinaryType case ArrowType.LargeUtf8.INSTANCE => StringType @@ -83,13 +85,15 @@ private[sql] object ArrowUtils { case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType case ts: ArrowType.Timestamp - if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => TimestampNTZType + if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => + TimestampNTZType case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType case ArrowType.Null.INSTANCE => NullType - case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType() + case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => + YearMonthIntervalType() case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType() - case ci: ArrowType.Interval - if ci.getUnit == IntervalUnit.MONTH_DAY_NANO => CalendarIntervalType + case ci: ArrowType.Interval if ci.getUnit == IntervalUnit.MONTH_DAY_NANO => + CalendarIntervalType case _ => throw ExecutionErrors.unsupportedArrowTypeError(dt) } @@ -103,37 +107,54 @@ private[sql] object ArrowUtils { dt match { case ArrayType(elementType, containsNull) => val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) - new Field(name, fieldType, - Seq(toArrowField("element", elementType, containsNull, timeZoneId, - largeVarTypes)).asJava) + new Field( + name, + fieldType, + Seq( + toArrowField("element", elementType, containsNull, timeZoneId, largeVarTypes)).asJava) case StructType(fields) => val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) - new Field(name, fieldType, - fields.map { field => - toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes) - }.toImmutableArraySeq.asJava) + new Field( + name, + fieldType, + fields + .map { field => + toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes) + } + .toImmutableArraySeq + .asJava) case MapType(keyType, valueType, valueContainsNull) => val mapType = new FieldType(nullable, new ArrowType.Map(false), null) // Note: Map Type struct can not be null, Struct Type key field can not be null - new Field(name, mapType, - Seq(toArrowField(MapVector.DATA_VECTOR_NAME, - new StructType() - .add(MapVector.KEY_NAME, keyType, nullable = false) - .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull), - nullable = false, - timeZoneId, - largeVarTypes)).asJava) + new Field( + name, + mapType, + Seq( + toArrowField( + MapVector.DATA_VECTOR_NAME, + new StructType() + .add(MapVector.KEY_NAME, keyType, nullable = false) + .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull), + nullable = false, + timeZoneId, + largeVarTypes)).asJava) case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes) case _: VariantType => - val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null, + val fieldType = new FieldType( + nullable, + ArrowType.Struct.INSTANCE, + null, Map("variant" -> "true").asJava) - new 
Field(name, fieldType, - Seq(toArrowField("value", BinaryType, false, timeZoneId, largeVarTypes), + new Field( + name, + fieldType, + Seq( + toArrowField("value", BinaryType, false, timeZoneId, largeVarTypes), toArrowField("metadata", BinaryType, false, timeZoneId, largeVarTypes)).asJava) case dataType => - val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId, - largeVarTypes), null) + val fieldType = + new FieldType(nullable, toArrowType(dataType, timeZoneId, largeVarTypes), null) new Field(name, fieldType, Seq.empty[Field].asJava) } } @@ -149,9 +170,12 @@ private[sql] object ArrowUtils { val elementField = field.getChildren().get(0) val elementType = fromArrowField(elementField) ArrayType(elementType, containsNull = elementField.isNullable) - case ArrowType.Struct.INSTANCE if field.getMetadata.getOrDefault("variant", "") == "true" - && field.getChildren.asScala.map(_.getName).asJava - .containsAll(Seq("value", "metadata").asJava) => + case ArrowType.Struct.INSTANCE + if field.getMetadata.getOrDefault("variant", "") == "true" + && field.getChildren.asScala + .map(_.getName) + .asJava + .containsAll(Seq("value", "metadata").asJava) => VariantType case ArrowType.Struct.INSTANCE => val fields = field.getChildren().asScala.map { child => @@ -163,7 +187,9 @@ private[sql] object ArrowUtils { } } - /** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */ + /** + * Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType + */ def toArrowSchema( schema: StructType, timeZoneId: String, @@ -187,14 +213,17 @@ private[sql] object ArrowUtils { } private def deduplicateFieldNames( - dt: DataType, errorOnDuplicatedFieldNames: Boolean): DataType = dt match { - case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames) + dt: DataType, + errorOnDuplicatedFieldNames: Boolean): DataType = dt match { + case udt: UserDefinedType[_] => + deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames) case st @ StructType(fields) => val newNames = if (st.names.toSet.size == st.names.length) { st.names } else { if (errorOnDuplicatedFieldNames) { - throw ExecutionErrors.duplicatedFieldNameInArrowStructError(st.names.toImmutableArraySeq) + throw ExecutionErrors.duplicatedFieldNameInArrowStructError( + st.names.toImmutableArraySeq) } val genNawName = st.names.groupBy(identity).map { case (name, names) if names.length > 1 => @@ -207,7 +236,10 @@ private[sql] object ArrowUtils { val newFields = fields.zip(newNames).map { case (StructField(_, dataType, nullable, metadata), name) => StructField( - name, deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames), nullable, metadata) + name, + deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames), + nullable, + metadata) } StructType(newFields) case ArrayType(elementType, containsNull) => diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java index 07a9409bc57a2..18646f67975c0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java @@ -17,20 +17,25 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.SparkBuildInfo; -import org.apache.spark.sql.errors.QueryExecutionErrors; -import 
org.apache.spark.unsafe.types.UTF8String; -import org.apache.spark.util.VersionUtils; - -import javax.crypto.Cipher; -import javax.crypto.spec.GCMParameterSpec; -import javax.crypto.spec.IvParameterSpec; -import javax.crypto.spec.SecretKeySpec; import java.nio.ByteBuffer; import java.security.GeneralSecurityException; import java.security.SecureRandom; import java.security.spec.AlgorithmParameterSpec; +import java.text.BreakIterator; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import javax.crypto.Cipher; +import javax.crypto.spec.GCMParameterSpec; +import javax.crypto.spec.IvParameterSpec; +import javax.crypto.spec.SecretKeySpec; +import org.apache.spark.SparkBuildInfo; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.errors.QueryExecutionErrors; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.util.VersionUtils; /** * A utility class for constructing expressions. @@ -272,4 +277,42 @@ private static byte[] aesInternal( throw QueryExecutionErrors.aesCryptoError(e.getMessage()); } } + + public static ArrayData getSentences( + UTF8String str, + UTF8String language, + UTF8String country) { + if (str == null) return null; + Locale locale; + if (language != null && country != null) { + locale = new Locale(language.toString(), country.toString()); + } else if (language != null) { + locale = new Locale(language.toString()); + } else { + locale = Locale.US; + } + String sentences = str.toString(); + BreakIterator sentenceInstance = BreakIterator.getSentenceInstance(locale); + sentenceInstance.setText(sentences); + + int sentenceIndex = 0; + List res = new ArrayList<>(); + while (sentenceInstance.next() != BreakIterator.DONE) { + String sentence = sentences.substring(sentenceIndex, sentenceInstance.current()); + sentenceIndex = sentenceInstance.current(); + BreakIterator wordInstance = BreakIterator.getWordInstance(locale); + wordInstance.setText(sentence); + int wordIndex = 0; + List words = new ArrayList<>(); + while (wordInstance.next() != BreakIterator.DONE) { + String word = sentence.substring(wordIndex, wordInstance.current()); + wordIndex = wordInstance.current(); + if (Character.isLetterOrDigit(word.charAt(0))) { + words.add(UTF8String.fromString(word)); + } + } + res.add(new GenericArrayData(words.toArray(new UTF8String[0]))); + } + return new GenericArrayData(res.toArray(new GenericArrayData[0])); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Column.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Column.java index b191438dbc3ee..8b32940d7a657 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Column.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Column.java @@ -53,7 +53,7 @@ static Column create( boolean nullable, String comment, String metadataInJSON) { - return new ColumnImpl(name, dataType, nullable, comment, null, null, metadataInJSON); + return new ColumnImpl(name, dataType, nullable, comment, null, null, null, metadataInJSON); } static Column create( @@ -63,7 +63,8 @@ static Column create( String comment, ColumnDefaultValue defaultValue, String metadataInJSON) { - return new ColumnImpl(name, dataType, nullable, comment, defaultValue, null, metadataInJSON); + return new ColumnImpl(name, dataType, nullable, comment, defaultValue, + null, null, metadataInJSON); } static Column create( @@ -74,7 +75,18 @@ static 
Column create( String generationExpression, String metadataInJSON) { return new ColumnImpl(name, dataType, nullable, comment, null, - generationExpression, metadataInJSON); + generationExpression, null, metadataInJSON); + } + + static Column create( + String name, + DataType dataType, + boolean nullable, + String comment, + IdentityColumnSpec identityColumnSpec, + String metadataInJSON) { + return new ColumnImpl(name, dataType, nullable, comment, null, + null, identityColumnSpec, metadataInJSON); } /** @@ -113,6 +125,12 @@ static Column create( @Nullable String generationExpression(); + /** + * Returns the identity column specification of this table column. Null means no identity column. + */ + @Nullable + IdentityColumnSpec identityColumnSpec(); + /** * Returns the column metadata in JSON format. */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java index 5ccb15ff1f0a4..dceac1b484cf2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java @@ -59,5 +59,23 @@ public enum TableCatalogCapability { * {@link TableCatalog#createTable}. * See {@link Column#defaultValue()}. */ - SUPPORT_COLUMN_DEFAULT_VALUE + SUPPORT_COLUMN_DEFAULT_VALUE, + + /** + * Signals that the TableCatalog supports defining identity columns upon table creation in SQL. + *
<p>
    + * Without this capability, any create/replace table statements with an identity column defined + * in the table schema will throw an exception during analysis. + *
<p>
    + * An identity column is defined with syntax: + * {@code colName colType GENERATED ALWAYS AS IDENTITY(identityColumnSpec)} + * or + * {@code colName colType GENERATED BY DEFAULT AS IDENTITY(identityColumnSpec)} + * identityColumnSpec is defined with syntax: {@code [START WITH start | INCREMENT BY step]*} + *
<p>
    + * IdentitySpec is included in the column definition for APIs like + * {@link TableCatalog#createTable}. + * See {@link Column#identityColumnSpec()}. + */ + SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java index 90d531ae21892..18c76833c5879 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java @@ -32,6 +32,11 @@ */ @Evolving public interface ProcedureParameter { + /** + * A field metadata key that indicates whether an argument is passed by name. + */ + String BY_NAME_METADATA_KEY = "BY_NAME"; + /** * Creates a builder for an IN procedure parameter. * diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java index ee9a09055243b..1a91fd21bf07e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java @@ -35,6 +35,12 @@ public interface UnboundProcedure extends Procedure { * validate if the input types are compatible while binding or delegate that to Spark. Regardless, * Spark will always perform the final validation of the arguments and rearrange them as needed * based on {@link BoundProcedure#parameters() reported parameters}. + *
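The identity-column capability documented just above can be exercised with DDL of the following shape. This is a hedged sketch only: `testcat.ns.events` is a placeholder for a table in a catalog that actually reports `SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS`, and `spark` is an assumed `SparkSession`.

```scala
// Follows the documented syntax: colName colType GENERATED ALWAYS AS IDENTITY(identityColumnSpec),
// where identityColumnSpec is [START WITH start | INCREMENT BY step]*
spark.sql("""
  CREATE TABLE testcat.ns.events (
    id   BIGINT GENERATED ALWAYS AS IDENTITY (START WITH 1 INCREMENT BY 1),
    data STRING
  )
""")
```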
<p>
    + * The provided {@code inputType} is based on the procedure arguments. If an argument is passed + * by name, its metadata will indicate this with {@link ProcedureParameter#BY_NAME_METADATA_KEY} + * set to {@code true}. In such cases, the field name will match the name of the target procedure + * parameter. If the argument is not named, {@link ProcedureParameter#BY_NAME_METADATA_KEY} will + * not be set and the name will be assigned randomly. * * @param inputType the input types to bind to * @return the bound procedure that is most suitable for the given input types diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala deleted file mode 100644 index 9b95f74db3a49..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ /dev/null @@ -1,359 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.lang.reflect.Modifier - -import scala.reflect.{classTag, ClassTag} -import scala.reflect.runtime.universe.TypeTag - -import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast} -import org.apache.spark.sql.catalyst.expressions.objects.{DecodeUsingSerializer, EncodeUsingSerializer} -import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.types._ - -/** - * Methods for creating an [[Encoder]]. - * - * @since 1.6.0 - */ -object Encoders { - - /** - * An encoder for nullable boolean type. - * The Scala primitive encoder is available as [[scalaBoolean]]. - * @since 1.6.0 - */ - def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() - - /** - * An encoder for nullable byte type. - * The Scala primitive encoder is available as [[scalaByte]]. - * @since 1.6.0 - */ - def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() - - /** - * An encoder for nullable short type. - * The Scala primitive encoder is available as [[scalaShort]]. - * @since 1.6.0 - */ - def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() - - /** - * An encoder for nullable int type. - * The Scala primitive encoder is available as [[scalaInt]]. - * @since 1.6.0 - */ - def INT: Encoder[java.lang.Integer] = ExpressionEncoder() - - /** - * An encoder for nullable long type. - * The Scala primitive encoder is available as [[scalaLong]]. - * @since 1.6.0 - */ - def LONG: Encoder[java.lang.Long] = ExpressionEncoder() - - /** - * An encoder for nullable float type. - * The Scala primitive encoder is available as [[scalaFloat]]. 
- * @since 1.6.0 - */ - def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() - - /** - * An encoder for nullable double type. - * The Scala primitive encoder is available as [[scalaDouble]]. - * @since 1.6.0 - */ - def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() - - /** - * An encoder for nullable string type. - * - * @since 1.6.0 - */ - def STRING: Encoder[java.lang.String] = ExpressionEncoder() - - /** - * An encoder for nullable decimal type. - * - * @since 1.6.0 - */ - def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder() - - /** - * An encoder for nullable date type. - * - * @since 1.6.0 - */ - def DATE: Encoder[java.sql.Date] = ExpressionEncoder() - - /** - * Creates an encoder that serializes instances of the `java.time.LocalDate` class - * to the internal representation of nullable Catalyst's DateType. - * - * @since 3.0.0 - */ - def LOCALDATE: Encoder[java.time.LocalDate] = ExpressionEncoder() - - /** - * Creates an encoder that serializes instances of the `java.time.LocalDateTime` class - * to the internal representation of nullable Catalyst's TimestampNTZType. - * - * @since 3.4.0 - */ - def LOCALDATETIME: Encoder[java.time.LocalDateTime] = ExpressionEncoder() - - /** - * An encoder for nullable timestamp type. - * - * @since 1.6.0 - */ - def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder() - - /** - * Creates an encoder that serializes instances of the `java.time.Instant` class - * to the internal representation of nullable Catalyst's TimestampType. - * - * @since 3.0.0 - */ - def INSTANT: Encoder[java.time.Instant] = ExpressionEncoder() - - /** - * An encoder for arrays of bytes. - * - * @since 1.6.1 - */ - def BINARY: Encoder[Array[Byte]] = ExpressionEncoder() - - /** - * Creates an encoder that serializes instances of the `java.time.Duration` class - * to the internal representation of nullable Catalyst's DayTimeIntervalType. - * - * @since 3.2.0 - */ - def DURATION: Encoder[java.time.Duration] = ExpressionEncoder() - - /** - * Creates an encoder that serializes instances of the `java.time.Period` class - * to the internal representation of nullable Catalyst's YearMonthIntervalType. - * - * @since 3.2.0 - */ - def PERIOD: Encoder[java.time.Period] = ExpressionEncoder() - - /** - * Creates an encoder for Java Bean of type T. - * - * T must be publicly accessible. - * - * supported types for java bean field: - * - primitive types: boolean, int, double, etc. - * - boxed types: Boolean, Integer, Double, etc. - * - String - * - java.math.BigDecimal, java.math.BigInteger - * - time related: java.sql.Date, java.sql.Timestamp, java.time.LocalDate, java.time.Instant - * - collection types: array, java.util.List, and map - * - nested java bean. - * - * @since 1.6.0 - */ - def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) - - /** - * Creates a [[Row]] encoder for schema `schema`. - * - * @since 3.5.0 - */ - def row(schema: StructType): Encoder[Row] = ExpressionEncoder(schema) - - /** - * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. - * This encoder maps T into a single byte array (binary) field. - * - * T must be publicly accessible. - * - * @since 1.6.0 - */ - def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) - - /** - * Creates an encoder that serializes objects of type T using Kryo. - * This encoder maps T into a single byte array (binary) field. - * - * T must be publicly accessible. 
- * - * @since 1.6.0 - */ - def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) - - /** - * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java - * serialization. This encoder maps T into a single byte array (binary) field. - * - * T must be publicly accessible. - * - * @note This is extremely inefficient and should only be used as the last resort. - * - * @since 1.6.0 - */ - def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) - - /** - * Creates an encoder that serializes objects of type T using generic Java serialization. - * This encoder maps T into a single byte array (binary) field. - * - * T must be publicly accessible. - * - * @note This is extremely inefficient and should only be used as the last resort. - * - * @since 1.6.0 - */ - def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) - - /** Throws an exception if T is not a public class. */ - private def validatePublicClass[T: ClassTag](): Unit = { - if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) { - throw QueryExecutionErrors.notPublicClassError(classTag[T].runtimeClass.getName) - } - } - - /** A way to construct encoders using generic serializers. */ - private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { - if (classTag[T].runtimeClass.isPrimitive) { - throw QueryExecutionErrors.primitiveTypesNotSupportedError() - } - - validatePublicClass[T]() - - ExpressionEncoder[T]( - objSerializer = - EncodeUsingSerializer( - BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo), - objDeserializer = - DecodeUsingSerializer[T]( - Cast(GetColumnByOrdinal(0, BinaryType), BinaryType), - classTag[T], - kryo = useKryo), - clsTag = classTag[T] - ) - } - - /** - * An encoder for 2-ary tuples. - * - * @since 1.6.0 - */ - def tuple[T1, T2]( - e1: Encoder[T1], - e2: Encoder[T2]): Encoder[(T1, T2)] = { - ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2)) - } - - /** - * An encoder for 3-ary tuples. - * - * @since 1.6.0 - */ - def tuple[T1, T2, T3]( - e1: Encoder[T1], - e2: Encoder[T2], - e3: Encoder[T3]): Encoder[(T1, T2, T3)] = { - ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3)) - } - - /** - * An encoder for 4-ary tuples. - * - * @since 1.6.0 - */ - def tuple[T1, T2, T3, T4]( - e1: Encoder[T1], - e2: Encoder[T2], - e3: Encoder[T3], - e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { - ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4)) - } - - /** - * An encoder for 5-ary tuples. - * - * @since 1.6.0 - */ - def tuple[T1, T2, T3, T4, T5]( - e1: Encoder[T1], - e2: Encoder[T2], - e3: Encoder[T3], - e4: Encoder[T4], - e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { - ExpressionEncoder.tuple( - encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5)) - } - - /** - * An encoder for Scala's product type (tuples, case classes, etc). - * @since 2.0.0 - */ - def product[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() - - /** - * An encoder for Scala's primitive int type. - * @since 2.0.0 - */ - def scalaInt: Encoder[Int] = ExpressionEncoder() - - /** - * An encoder for Scala's primitive long type. - * @since 2.0.0 - */ - def scalaLong: Encoder[Long] = ExpressionEncoder() - - /** - * An encoder for Scala's primitive double type. - * @since 2.0.0 - */ - def scalaDouble: Encoder[Double] = ExpressionEncoder() - - /** - * An encoder for Scala's primitive float type. 
- * @since 2.0.0 - */ - def scalaFloat: Encoder[Float] = ExpressionEncoder() - - /** - * An encoder for Scala's primitive byte type. - * @since 2.0.0 - */ - def scalaByte: Encoder[Byte] = ExpressionEncoder() - - /** - * An encoder for Scala's primitive short type. - * @since 2.0.0 - */ - def scalaShort: Encoder[Short] = ExpressionEncoder() - - /** - * An encoder for Scala's primitive boolean type. - * @since 2.0.0 - */ - def scalaBoolean: Encoder[Boolean] = ExpressionEncoder() - -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 40b49506b58aa..4752434015375 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.{expressions => exprs} import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec} +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder} import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder} import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast} -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, IntervalUtils} import org.apache.spark.sql.types._ @@ -411,6 +411,12 @@ object DeserializerBuildHelper { val result = InitializeJavaBean(newInstance, setters.toMap) exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result) + case TransformingEncoder(tag, _, codec) if codec == JavaSerializationCodec => + DecodeUsingSerializer(path, tag, kryo = false) + + case TransformingEncoder(tag, _, codec) if codec == KryoSerializationCodec => + DecodeUsingSerializer(path, tag, kryo = true) + case TransformingEncoder(tag, encoder, provider) => Invoke( Literal.create(provider(), ObjectType(classOf[Codec[_, _]])), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala index 
408bd65333cac..20cf80e88e42a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala @@ -26,7 +26,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} * An [[InternalRow]] that projects particular columns from another [[InternalRow]] without copying * the underlying data. */ -case class ProjectingInternalRow(schema: StructType, colOrdinals: Seq[Int]) extends InternalRow { +case class ProjectingInternalRow(schema: StructType, + colOrdinals: IndexedSeq[Int]) extends InternalRow { assert(schema.size == colOrdinals.size) private var row: InternalRow = _ @@ -116,3 +117,9 @@ case class ProjectingInternalRow(schema: StructType, colOrdinals: Seq[Int]) exte row.get(colOrdinals(ordinal), dataType) } } + +object ProjectingInternalRow { + def apply(schema: StructType, colOrdinals: Seq[Int]): ProjectingInternalRow = { + new ProjectingInternalRow(schema, colOrdinals.toIndexedSeq) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index 38bf0651d6f1c..daebe15c298f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -21,7 +21,7 @@ import scala.language.existentials import org.apache.spark.sql.catalyst.{expressions => exprs} import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec} +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder} import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor} import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData} @@ -398,6 +398,12 @@ object SerializerBuildHelper { } createSerializerForObject(input, serializedFields) + case TransformingEncoder(_, _, codec) if codec == JavaSerializationCodec => + EncodeUsingSerializer(input, kryo = false) + + case TransformingEncoder(_, _, codec) if codec == KryoSerializationCodec => + EncodeUsingSerializer(input, kryo = true) + case TransformingEncoder(_, encoder, codecProvider) => val encoded = Invoke( Literal(codecProvider(), ObjectType(classOf[Codec[_, _]])), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0164af945ca28..9e5b1d1254c87 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import scala.util.{Failure, Random, Success, Try} -import org.apache.spark.{SparkException, SparkUnsupportedOperationException} +import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ @@ -50,6 +50,7 @@ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition} import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction} +import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -310,6 +311,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ExtractGenerator :: ResolveGenerate :: ResolveFunctions :: + ResolveProcedures :: + BindProcedures :: ResolveTableSpec :: ResolveAliases :: ResolveSubquery :: @@ -2611,6 +2614,66 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } + /** + * A rule that resolves procedures. + */ + object ResolveProcedures extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( + _.containsPattern(UNRESOLVED_PROCEDURE), ruleId) { + case Call(UnresolvedProcedure(CatalogAndIdentifier(catalog, ident)), args, execute) => + val procedureCatalog = catalog.asProcedureCatalog + val procedure = load(procedureCatalog, ident) + Call(ResolvedProcedure(procedureCatalog, ident, procedure), args, execute) + } + + private def load(catalog: ProcedureCatalog, ident: Identifier): UnboundProcedure = { + try { + catalog.loadProcedure(ident) + } catch { + case e: Exception if !e.isInstanceOf[SparkThrowable] => + val nameParts = catalog.name +: ident.asMultipartIdentifier + throw QueryCompilationErrors.failedToLoadRoutineError(nameParts, e) + } + } + } + + /** + * A rule that binds procedures to the input types and rearranges arguments as needed. 
+ */ + object BindProcedures extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case Call(ResolvedProcedure(catalog, ident, unbound: UnboundProcedure), args, execute) + if args.forall(_.resolved) => + val inputType = extractInputType(args) + val bound = unbound.bind(inputType) + validateParameterModes(bound) + val rearrangedArgs = NamedParametersSupport.defaultRearrange(bound, args) + Call(ResolvedProcedure(catalog, ident, bound), rearrangedArgs, execute) + } + + private def extractInputType(args: Seq[Expression]): StructType = { + val fields = args.zipWithIndex.map { + case (NamedArgumentExpression(name, value), _) => + StructField(name, value.dataType, value.nullable, byNameMetadata) + case (arg, index) => + StructField(s"param$index", arg.dataType, arg.nullable) + } + StructType(fields) + } + + private def byNameMetadata: Metadata = { + new MetadataBuilder() + .putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, value = true) + .build() + } + + private def validateParameterModes(procedure: BoundProcedure): Unit = { + procedure.parameters.find(_.mode != ProcedureParameter.Mode.IN).foreach { param => + throw SparkException.internalError(s"Unsupported parameter mode: ${param.mode}") + } + } + } + /** * This rule resolves and rewrites subqueries inside expressions. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 17b1c4e249f57..3afe0ec8e9a7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -77,6 +77,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { override def typeCoercionRules: List[Rule[LogicalPlan]] = UnpivotCoercion :: WidenSetOperationTypes :: + ProcedureArgumentCoercion :: new AnsiCombinedTypeCoercionRule( CollationTypeCasts :: InConversion :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a9fbe548ba39e..5a9d5cd87ecc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -250,6 +250,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB context = u.origin.getQueryContext, summary = u.origin.context.summary) + case u: UnresolvedInlineTable if unresolvedInlineTableContainsScalarSubquery(u) => + throw QueryCompilationErrors.inlineTableContainsScalarSubquery(u) + case command: V2PartitionCommand => command.table match { case r @ ResolvedTable(_, _, table, _) => table match { @@ -673,6 +676,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB varName, c.defaultExpr.originalSQL) + case c: Call if c.resolved && c.bound && c.checkArgTypes().isFailure => + c.checkArgTypes() match { + case mismatch: TypeCheckResult.DataTypeMismatch => + c.dataTypeMismatch("CALL", mismatch) + case _ => + throw SparkException.internalError("Invalid input for procedure") + } + case _ => // Falls back to the following checks } @@ -1559,6 +1570,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB case _ => } } + + private def unresolvedInlineTableContainsScalarSubquery( + 
unresolvedInlineTable: UnresolvedInlineTable) = { + unresolvedInlineTable.rows.exists { row => + row.exists { expression => + expression.exists(_.isInstanceOf[ScalarSubquery]) + } + } + } } // a heap of the preempted error that only keeps the top priority element, representing the sole diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala index 0fa11b9c45038..e22a4b941b30c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala @@ -105,11 +105,15 @@ object DeduplicateRelations extends Rule[LogicalPlan] { case p: LogicalPlan if p.isStreaming => (plan, false) case m: MultiInstanceRelation => - deduplicateAndRenew[LogicalPlan with MultiInstanceRelation]( - existingRelations, - m, - _.output.map(_.exprId.id), - node => node.newInstance().asInstanceOf[LogicalPlan with MultiInstanceRelation]) + val planWrapper = RelationWrapper(m.getClass, m.output.map(_.exprId.id)) + if (existingRelations.contains(planWrapper)) { + val newNode = m.newInstance() + newNode.copyTagsFrom(m) + (newNode, true) + } else { + existingRelations.add(planWrapper) + (m, false) + } case p: Project => deduplicateAndRenew[Project]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index dfe1bd12bb7ff..d03d8114e9976 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -384,7 +384,9 @@ object FunctionRegistry { expression[Rand]("rand"), expression[Rand]("random", true, Some("3.0.0")), expression[Randn]("randn"), + expression[RandStr]("randstr"), expression[Stack]("stack"), + expression[Uniform]("uniform"), expression[ZeroIfNull]("zeroifnull"), CaseWhen.registryEntry, @@ -839,6 +841,7 @@ object FunctionRegistry { expressionBuilder("try_variant_get", TryVariantGetExpressionBuilder), expression[SchemaOfVariant]("schema_of_variant"), expression[SchemaOfVariantAgg]("schema_of_variant_agg"), + expression[ToVariantObject]("to_variant_object"), // cast expression[Cast]("cast"), @@ -1157,6 +1160,7 @@ object TableFunctionRegistry { generator[PosExplode]("posexplode"), generator[PosExplode]("posexplode_outer", outer = true), generator[Stack]("stack"), + generator[Collations]("collations"), generator[SQLKeywords]("sql_keywords"), generator[VariantExplode]("variant_explode"), generator[VariantExplode]("variant_explode_outer", outer = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 08c5b3531b4c8..5983346ff1e27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure import 
org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation} @@ -202,6 +203,20 @@ abstract class TypeCoercionBase { } } + /** + * A type coercion rule that implicitly casts procedure arguments to expected types. + */ + object ProcedureArgumentCoercion extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case c @ Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args, _) if c.resolved => + val expectedDataTypes = procedure.parameters.map(_.dataType) + val coercedArgs = args.zip(expectedDataTypes).map { + case (arg, expectedType) => implicitCast(arg, expectedType).getOrElse(arg) + } + c.copy(args = coercedArgs) + } + } + /** * Widens the data types of the [[Unpivot]] values. */ @@ -838,6 +853,7 @@ object TypeCoercion extends TypeCoercionBase { override def typeCoercionRules: List[Rule[LogicalPlan]] = UnpivotCoercion :: WidenSetOperationTypes :: + ProcedureArgumentCoercion :: new CombinedTypeCoercionRule( CollationTypeCasts :: InConversion :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index c0689eb121679..daab9e4d78bf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -67,9 +67,13 @@ package object analysis { } def dataTypeMismatch(expr: Expression, mismatch: DataTypeMismatch): Nothing = { + dataTypeMismatch(toSQLExpr(expr), mismatch) + } + + def dataTypeMismatch(sqlExpr: String, mismatch: DataTypeMismatch): Nothing = { throw new AnalysisException( errorClass = s"DATATYPE_MISMATCH.${mismatch.errorSubClass}", - messageParameters = mismatch.messageParameters + ("sqlExpr" -> toSQLExpr(expr)), + messageParameters = mismatch.messageParameters + ("sqlExpr" -> sqlExpr), origin = t.origin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 60d979e9c7afb..6f445b1e88d70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -974,3 +974,13 @@ case object UnresolvedWithinGroup extends LeafExpression with Unevaluable { override def dataType: DataType = throw new UnresolvedException("dataType") override lazy val resolved = false } + +case class UnresolvedTranspose( + indices: Seq[Expression], + child: LogicalPlan +) extends UnresolvedUnaryNode { + final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_TRANSPOSE) + + override protected def withNewChildInternal(newChild: LogicalPlan): UnresolvedTranspose = + copy(child = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala index ecdf40e87a894..dee78b8f03af4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala @@ -23,13 +23,14 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable import 
org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, LeafExpression, Unevaluable} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC} +import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC, UNRESOLVED_PROCEDURE} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, Table, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, ProcedureCatalog, Table, TableCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.catalog.functions.UnboundFunction +import org.apache.spark.sql.connector.catalog.procedures.Procedure import org.apache.spark.sql.types.{DataType, StructField} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ @@ -135,6 +136,12 @@ case class UnresolvedFunctionName( case class UnresolvedIdentifier(nameParts: Seq[String], allowTemp: Boolean = false) extends UnresolvedLeafNode +/** + * A procedure identifier that should be resolved into [[ResolvedProcedure]]. + */ +case class UnresolvedProcedure(nameParts: Seq[String]) extends UnresolvedLeafNode { + final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_PROCEDURE) +} /** * A resolved leaf node whose statistics has no meaning. @@ -192,6 +199,12 @@ case class ResolvedFieldName(path: Seq[String], field: StructField) extends Fiel case class ResolvedFieldPosition(position: ColumnPosition) extends FieldPosition +case class ResolvedProcedure( + catalog: ProcedureCatalog, + ident: Identifier, + procedure: Procedure) extends LeafNodeWithoutStats { + override def output: Seq[Attribute] = Nil +} /** * A plan containing resolved persistent views. 
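Putting the new plan nodes together with the `ResolveProcedures`/`BindProcedures` analyzer rules added earlier in this patch: a `CALL` is first resolved to a `ResolvedProcedure` that holds an `UnboundProcedure`, which is then bound against a `StructType` built from the call arguments. A rough sketch of the input type that `BindProcedures.extractInputType` would produce for `CALL testcat.ns.refresh(42, retries => 3)` (the procedure name and argument values are illustrative, not from the patch):

```scala
import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter
import org.apache.spark.sql.types._

// Named arguments carry BY_NAME metadata and keep their parameter name;
// positional arguments get a synthetic name based on their position.
val byName = new MetadataBuilder()
  .putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, value = true)
  .build()

val inputType = StructType(Seq(
  StructField("param0", IntegerType, nullable = false),            // positional: 42
  StructField("retries", IntegerType, nullable = false, byName)))  // named: retries => 3

// UnboundProcedure.bind(inputType) then returns the BoundProcedure whose reported
// parameters drive the final argument validation and rearrangement.
```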
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 2c27da3cf6e15..5444ab6845867 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -21,12 +21,12 @@ import java.util.Locale import scala.util.control.Exception.allCatch +import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter} import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT -import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -138,7 +138,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { case BooleanType => tryParseBoolean(field) case StringType => StringType case other: DataType => - throw QueryExecutionErrors.dataTypeUnexpectedError(other) + throw SparkException.internalError(s"Unexpected data type $other") } compatibleType(typeSoFar, typeElemInfer).getOrElse(StringType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 0b5ce65fed6df..d7d53230470d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.{DeserializerBuildHelper, InternalRow, Java import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, NewInstance} import org.apache.spark.sql.catalyst.optimizer.{ReassignLambdaVariableID, SimplifyCasts} import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LeafNode, LocalRelation} import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -54,9 +54,9 @@ object ExpressionEncoder { def apply[T](enc: AgnosticEncoder[T]): ExpressionEncoder[T] = { new ExpressionEncoder[T]( + enc, SerializerBuildHelper.createSerializer(enc), - DeserializerBuildHelper.createDeserializer(enc), - enc.clsTag) + DeserializerBuildHelper.createDeserializer(enc)) } def apply(schema: StructType): ExpressionEncoder[Row] = apply(schema, lenient = false) @@ -70,107 +70,6 @@ object ExpressionEncoder { apply(JavaTypeInference.encoderFor(beanClass)) } - /** - * Given a set of N encoders, constructs a new encoder that produce objects as items in an - * N-tuple. Note that these encoders should be unresolved so that information about - * name/positional binding is preserved. - * When `useNullSafeDeserializer` is true, the deserialization result for a child will be null if - * the input is null. 
It is false by default as most deserializers handle null input properly and - * don't require an extra null check. Some of them are null-tolerant, such as the deserializer for - * `Option[T]`, and we must not set it to true in this case. - */ - def tuple( - encoders: Seq[ExpressionEncoder[_]], - useNullSafeDeserializer: Boolean = false): ExpressionEncoder[_] = { - if (encoders.length > 22) { - throw QueryExecutionErrors.elementsOfTupleExceedLimitError() - } - - encoders.foreach(_.assertUnresolved()) - - val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - - val newSerializerInput = BoundReference(0, ObjectType(cls), nullable = true) - val serializers = encoders.zipWithIndex.map { case (enc, index) => - val boundRefs = enc.objSerializer.collect { case b: BoundReference => b }.distinct - assert(boundRefs.size == 1, "object serializer should have only one bound reference but " + - s"there are ${boundRefs.size}") - - val originalInputObject = boundRefs.head - val newInputObject = Invoke( - newSerializerInput, - s"_${index + 1}", - originalInputObject.dataType, - returnNullable = originalInputObject.nullable) - - val newSerializer = enc.objSerializer.transformUp { - case BoundReference(0, _, _) => newInputObject - } - - Alias(newSerializer, s"_${index + 1}")() - } - val newSerializer = CreateStruct(serializers) - - def nullSafe(input: Expression, result: Expression): Expression = { - If(IsNull(input), Literal.create(null, result.dataType), result) - } - - val newDeserializerInput = GetColumnByOrdinal(0, newSerializer.dataType) - val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => - val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct - assert(getColExprs.size == 1, "object deserializer should have only one " + - s"`GetColumnByOrdinal`, but there are ${getColExprs.size}") - - val input = GetStructField(newDeserializerInput, index) - val childDeserializer = enc.objDeserializer.transformUp { - case GetColumnByOrdinal(0, _) => input - } - - if (useNullSafeDeserializer && enc.objSerializer.nullable) { - nullSafe(input, childDeserializer) - } else { - childDeserializer - } - } - val newDeserializer = - NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false) - - new ExpressionEncoder[Any]( - nullSafe(newSerializerInput, newSerializer), - nullSafe(newDeserializerInput, newDeserializer), - ClassTag(cls)) - } - - // Tuple1 - def tuple[T](e: ExpressionEncoder[T]): ExpressionEncoder[Tuple1[T]] = - tuple(Seq(e)).asInstanceOf[ExpressionEncoder[Tuple1[T]]] - - def tuple[T1, T2]( - e1: ExpressionEncoder[T1], - e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] = - tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]] - - def tuple[T1, T2, T3]( - e1: ExpressionEncoder[T1], - e2: ExpressionEncoder[T2], - e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] = - tuple(Seq(e1, e2, e3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]] - - def tuple[T1, T2, T3, T4]( - e1: ExpressionEncoder[T1], - e2: ExpressionEncoder[T2], - e3: ExpressionEncoder[T3], - e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] = - tuple(Seq(e1, e2, e3, e4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]] - - def tuple[T1, T2, T3, T4, T5]( - e1: ExpressionEncoder[T1], - e2: ExpressionEncoder[T2], - e3: ExpressionEncoder[T3], - e4: ExpressionEncoder[T4], - e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] = - tuple(Seq(e1, e2, e3, e4, 
e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] - private val anyObjectType = ObjectType(classOf[Any]) /** @@ -228,6 +127,7 @@ object ExpressionEncoder { * A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer` * and a `deserializer`. * + * @param encoder the `AgnosticEncoder` for type `T`. * @param objSerializer An expression that can be used to encode a raw object to corresponding * Spark SQL representation that can be a primitive column, array, map or a * struct. This represents how Spark SQL generally serializes an object of @@ -236,13 +136,15 @@ object ExpressionEncoder { * representation. This represents how Spark SQL generally deserializes * a serialized value in Spark SQL representation back to an object of * type `T`. - * @param clsTag A classtag for `T`. */ case class ExpressionEncoder[T]( + encoder: AgnosticEncoder[T], objSerializer: Expression, - objDeserializer: Expression, - clsTag: ClassTag[T]) - extends Encoder[T] { + objDeserializer: Expression) + extends Encoder[T] + with ToAgnosticEncoder[T] { + + override def clsTag: ClassTag[T] = encoder.clsTag /** * A sequence of expressions, one for each top-level field that can be used to diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/KryoSerializationCodecImpl.scala similarity index 57% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/KryoSerializationCodecImpl.scala index f3e0c0aca29ca..49c7b41f77472 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/KryoSerializationCodecImpl.scala @@ -14,20 +14,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.spark.sql.catalyst.encoders +import java.nio.ByteBuffer -package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.expressions.objects.SerializerSupport -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.util.quoteNameParts -import org.apache.spark.sql.connector.catalog.Identifier -import org.apache.spark.util.ArrayImplicits._ +/** + * A codec that uses Kryo to (de)serialize arbitrary objects to and from a byte array. 
+ */ +class KryoSerializationCodecImpl extends Codec[Any, Array[Byte]] { + private val serializer = SerializerSupport.newSerializer(useKryo = true) + override def encode(in: Any): Array[Byte] = + serializer.serialize(in).array() -class CannotReplaceMissingTableException( - tableIdentifier: Identifier, - cause: Option[Throwable] = None) - extends AnalysisException( - errorClass = "TABLE_OR_VIEW_NOT_FOUND", - messageParameters = Map("relationName" - -> quoteNameParts((tableIdentifier.namespace :+ tableIdentifier.name).toImmutableArraySeq)), - cause = cause) + override def decode(out: Array[Byte]): Any = + serializer.deserialize(ByteBuffer.wrap(out)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala index de1122da646b7..c226e48c6be5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.internal.SQLConf // scalastyle:off line.size.limit @ExpressionDescription( - usage = "Usage: input [NOT] BETWEEN lower AND upper - evaluate if `input` is [not] in between `lower` and `upper`", + usage = "input [NOT] _FUNC_ lower AND upper - evaluate if `input` is [not] in between `lower` and `upper`", examples = """ Examples: > SELECT 0.5 _FUNC_ 0.1 AND 1.0; @@ -33,7 +33,7 @@ import org.apache.spark.sql.internal.SQLConf * lower - Lower bound of the between check. * upper - Upper bound of the between check. """, - since = "4.0.0", + since = "1.0.0", group = "conditional_funcs") case class Between private(input: Expression, lower: Expression, upper: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 4a2b4b28e690e..7a2799e99fe2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -128,7 +128,10 @@ object Cast extends QueryErrorsBase { case (TimestampType, _: NumericType) => true case (VariantType, _) => variant.VariantGet.checkDataType(to) - case (_, VariantType) => variant.VariantGet.checkDataType(from) + // Structs and Maps can't be cast to Variants since the Variant spec does not yet contain + // lossless equivalents for these types. The `to_variant_object` expression can be used instead + // to convert data of these types to Variant Objects. + case (_, VariantType) => variant.VariantGet.checkDataType(from, allowStructsAndMaps = false) case (ArrayType(fromType, fn), ArrayType(toType, tn)) => canAnsiCast(fromType, toType) && resolvableNullability(fn, tn) @@ -237,7 +240,10 @@ object Cast extends QueryErrorsBase { case (_: NumericType, _: NumericType) => true case (VariantType, _) => variant.VariantGet.checkDataType(to) - case (_, VariantType) => variant.VariantGet.checkDataType(from) + // Structs and Maps can't be cast to Variants since the Variant spec does not yet contain + // lossless equivalents for these types. The `to_variant_object` expression can be used instead + // to convert data of these types to Variant Objects. 
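For the `KryoSerializationCodecImpl` added above, a minimal round-trip sketch; the payload class is hypothetical and Kryo is assumed to handle it without explicit registration:

import org.apache.spark.sql.catalyst.encoders.KryoSerializationCodecImpl

case class Point(x: Int, y: Int) // hypothetical payload

val codec = new KryoSerializationCodecImpl
val bytes: Array[Byte] = codec.encode(Point(1, 2))    // Kryo-serialize to a byte array
val decoded = codec.decode(bytes).asInstanceOf[Point] // deserialize back to an object
assert(decoded == Point(1, 2))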
+ case (_, VariantType) => variant.VariantGet.checkDataType(from, allowStructsAndMaps = false) case (ArrayType(fromType, fn), ArrayType(toType, tn)) => canCast(fromType, toType) && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PipeSelect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PipeSelect.scala new file mode 100644 index 0000000000000..0b5479cc8f0ee --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PipeSelect.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction +import org.apache.spark.sql.catalyst.trees.TreePattern.{PIPE_OPERATOR_SELECT, RUNTIME_REPLACEABLE, TreePattern} +import org.apache.spark.sql.errors.QueryCompilationErrors + +/** + * Represents a SELECT clause when used with the |> SQL pipe operator. + * We use this to make sure that no aggregate functions exist in the SELECT expressions. + */ +case class PipeSelect(child: Expression) + extends UnaryExpression with RuntimeReplaceable { + final override val nodePatterns: Seq[TreePattern] = Seq(PIPE_OPERATOR_SELECT, RUNTIME_REPLACEABLE) + override def withNewChildInternal(newChild: Expression): Expression = PipeSelect(newChild) + override lazy val replacement: Expression = { + def visit(e: Expression): Unit = e match { + case a: AggregateFunction => + // If we used the pipe operator |> SELECT clause to specify an aggregate function, this is + // invalid; return an error message instructing the user to use the pipe operator + // |> AGGREGATE clause for this purpose instead. + throw QueryCompilationErrors.pipeOperatorSelectContainsAggregateFunction(a) + case _: WindowExpression => + // Window functions are allowed in pipe SELECT operators, so do not traverse into children. 
+ case _ => + e.children.foreach(visit) + } + visit(child) + child + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 4987e31b49911..8ad062ab0e2f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import com.google.common.primitives.{Doubles, Ints, Longs} +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} @@ -32,7 +33,6 @@ import org.apache.spark.sql.catalyst.types.PhysicalNumericType import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats} -import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ @@ -189,7 +189,7 @@ case class ApproximatePercentile( PhysicalNumericType.numeric(n) .toDouble(value.asInstanceOf[PhysicalNumericType#InternalType]) case other: DataType => - throw QueryExecutionErrors.dataTypeUnexpectedError(other) + throw SparkException.internalError(s"Unexpected data type $other") } buffer.add(doubleValue) } @@ -214,7 +214,7 @@ case class ApproximatePercentile( case DoubleType => doubleResult case _: DecimalType => doubleResult.map(Decimal(_)) case other: DataType => - throw QueryExecutionErrors.dataTypeUnexpectedError(other) + throw SparkException.internalError(s"Unexpected data type $other") } if (result.length == 0) { null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala index ba26c5a1022d0..eda2c742ab4b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala @@ -143,7 +143,8 @@ case class HistogramNumeric( if (buffer.getUsedBins < 1) { null } else { - val result = (0 until buffer.getUsedBins).map { index => + val array = new Array[AnyRef](buffer.getUsedBins) + (0 until buffer.getUsedBins).foreach { index => // Note that the 'coord.x' and 'coord.y' have double-precision floating point type here. val coord = buffer.getBin(index) if (propagateInputType) { @@ -163,16 +164,16 @@ case class HistogramNumeric( coord.x.toLong case _ => coord.x } - InternalRow.apply(result, coord.y) + array(index) = InternalRow.apply(result, coord.y) } else { // Otherwise, just apply the double-precision values in 'coord.x' and 'coord.y' to the // output row directly. In this case: 'SELECT histogram_numeric(val, 3) // FROM VALUES (0L), (1L), (2L), (10L) AS tab(col)' returns an array of structs where the // first field has DoubleType. 
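The `PipeSelect` wrapper introduced above only checks that no aggregate function appears in a piped SELECT list; window expressions are allowed, and aggregation is expected to go through the dedicated aggregate pipe clause named in the comments. A hedged usage sketch, assuming an active SparkSession in scope as `spark`, a made-up table `sales`, and the `FROM ... |>` surface syntax (the exact syntax is an assumption; only the `|> SELECT` and `|> AGGREGATE` clause names come from the comments above):

// Scalar expressions in a piped SELECT are fine.
spark.sql("FROM sales |> SELECT price * quantity AS revenue").show()

// An aggregate inside |> SELECT is rejected by the check above; the aggregate pipe
// clause is the intended alternative, e.g.:
//   FROM sales |> AGGREGATE sum(price * quantity) AS revenue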
- InternalRow.apply(coord.x, coord.y) + array(index) = InternalRow.apply(coord.x, coord.y) } } - new GenericArrayData(result) + new GenericArrayData(array) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index e77622a26d90a..c593c8bfb8341 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -170,7 +170,7 @@ case class CollectSet( override def eval(buffer: mutable.HashSet[Any]): Any = { val array = child.dataType match { case BinaryType => - buffer.iterator.map(_.asInstanceOf[ArrayData].toByteArray()).toArray + buffer.iterator.map(_.asInstanceOf[ArrayData].toByteArray()).toArray[Any] case _ => buffer.toArray } new GenericArrayData(array) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 375a2bde59230..5cdd3c7eb62d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -433,7 +433,7 @@ case class ArraysZip(children: Seq[Expression], names: Seq[Expression]) inputArrays.map(_.numElements()).max } - val result = new Array[InternalRow](biggestCardinality) + val result = new Array[AnyRef](biggestCardinality) val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex for (i <- 0 until biggestCardinality) { @@ -1058,20 +1058,26 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def checkInputDataTypes(): TypeCheckResult = base.dataType match { case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - ascendingOrder match { - case Literal(_: Boolean, BooleanType) => - TypeCheckResult.TypeCheckSuccess - case _ => - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> ordinalNumber(1), - "requiredType" -> toSQLType(BooleanType), - "inputSql" -> toSQLExpr(ascendingOrder), - "inputType" -> toSQLType(ascendingOrder.dataType)) - ) + if (!ascendingOrder.foldable) { + DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> toSQLId("ascendingOrder"), + "inputType" -> toSQLType(ascendingOrder.dataType), + "inputExpr" -> toSQLExpr(ascendingOrder))) + } else if (ascendingOrder.dataType != BooleanType) { + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(1), + "requiredType" -> toSQLType(BooleanType), + "inputSql" -> toSQLExpr(ascendingOrder), + "inputType" -> toSQLType(ascendingOrder.dataType)) + ) + } else { + TypeCheckResult.TypeCheckSuccess } - case ArrayType(dt, _) => + case ArrayType(_, _) => DataTypeMismatch( errorSubClass = "INVALID_ORDERING_TYPE", messageParameters = Map( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index ba1beab28d9a7..b8b47f2763f5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeNonCSAICollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -579,7 +579,7 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E override def third: Expression = keyValueDelim override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) override def dataType: DataType = MapType(first.dataType, first.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index ca74fefb9c032..56ecbf550e45c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TernaryLike @@ -188,6 +189,13 @@ case class CaseWhen( } override def nullable: Boolean = { + if (branches.exists(_._1 == TrueLiteral)) { + // if any of the branch is always true + // nullability check should only be related to branches + // before the TrueLiteral and value of the first TrueLiteral branch + val (h, t) = branches.span(_._1 != TrueLiteral) + return h.exists(_._2.nullable) || t.head._2.nullable + } // Result is nullable if any of the branch is nullable, or if the else value is nullable branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 2cc88a25f465d..dc58352a1b362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable +import scala.jdk.CollectionConverters.CollectionHasAsScala import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -28,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter} import 
org.apache.spark.sql.catalyst.trees.TreePattern.{GENERATOR, TreePattern} -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, MapData} import org.apache.spark.sql.catalyst.util.SQLKeywordUtils._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf @@ -618,3 +619,44 @@ case class SQLKeywords() extends LeafExpression with Generator with CodegenFallb override def prettyName: String = "sql_keywords" } + +@ExpressionDescription( + usage = """_FUNC_() - Get all of the Spark SQL string collations""", + examples = """ + Examples: + > SELECT * FROM _FUNC_() WHERE NAME = 'UTF8_BINARY'; + SYSTEM BUILTIN UTF8_BINARY NULL NULL ACCENT_SENSITIVE CASE_SENSITIVE NO_PAD NULL + """, + since = "4.0.0", + group = "generator_funcs") +case class Collations() extends LeafExpression with Generator with CodegenFallback { + override def elementSchema: StructType = new StructType() + .add("CATALOG", StringType, nullable = false) + .add("SCHEMA", StringType, nullable = false) + .add("NAME", StringType, nullable = false) + .add("LANGUAGE", StringType) + .add("COUNTRY", StringType) + .add("ACCENT_SENSITIVITY", StringType, nullable = false) + .add("CASE_SENSITIVITY", StringType, nullable = false) + .add("PAD_ATTRIBUTE", StringType, nullable = false) + .add("ICU_VERSION", StringType) + + override def eval(input: InternalRow): IterableOnce[InternalRow] = { + CollationFactory.listCollations().asScala.map(CollationFactory.loadCollationMeta).map { m => + InternalRow( + UTF8String.fromString(m.catalog), + UTF8String.fromString(m.schema), + UTF8String.fromString(m.collationName), + UTF8String.fromString(m.language), + UTF8String.fromString(m.country), + UTF8String.fromString( + if (m.accentSensitivity) "ACCENT_SENSITIVE" else "ACCENT_INSENSITIVE"), + UTF8String.fromString( + if (m.caseSensitivity) "CASE_SENSITIVE" else "CASE_INSENSITIVE"), + UTF8String.fromString(m.padAttribute), + UTF8String.fromString(m.icuVersion)) + } + } + + override def prettyName: String = "collations" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 7005d663a3f96..2037eb22fede6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -632,7 +632,8 @@ case class JsonToStructs( schema: DataType, options: Map[String, String], child: Expression, - timeZoneId: Option[String] = None) + timeZoneId: Option[String] = None, + variantAllowDuplicateKeys: Boolean = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS)) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback @@ -719,7 +720,8 @@ case class JsonToStructs( override def nullSafeEval(json: Any): Any = nullableSchema match { case _: VariantType => - VariantExpressionEvalUtils.parseJson(json.asInstanceOf[UTF8String]) + VariantExpressionEvalUtils.parseJson(json.asInstanceOf[UTF8String], + allowDuplicateKeys = variantAllowDuplicateKeys) case _ => converter(parser.parse(json.asInstanceOf[UTF8String])) } @@ -737,6 +739,12 @@ case class JsonToStructs( copy(child = newChild) } +object JsonToStructs { + def unapply( + j: JsonToStructs): Option[(DataType, Map[String, String], Expression, Option[String])] = + Some((j.schema, j.options, 
j.child, j.timeZoneId)) +} + /** * Converts a [[StructType]], [[ArrayType]] or [[MapType]] to a JSON output string. */ @@ -1072,7 +1080,7 @@ case class JsonObjectKeys(child: Expression) extends UnaryExpression with Codege // skip all the children of inner object or array parser.skipChildren() } - new GenericArrayData(arrayBufferOfKeys.toArray) + new GenericArrayData(arrayBufferOfKeys.toArray[Any]) } override protected def withNewChildInternal(newChild: Expression): JsonObjectKeys = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 00274a16b888b..ddba820414ae4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1293,7 +1293,7 @@ sealed trait BitShiftOperation * @param right number of bits to left shift. */ @ExpressionDescription( - usage = "base << exp - Bitwise left shift.", + usage = "base _FUNC_ exp - Bitwise left shift.", examples = """ Examples: > SELECT shiftleft(2, 1); @@ -1322,7 +1322,7 @@ case class ShiftLeft(left: Expression, right: Expression) extends BitShiftOperat * @param right number of bits to right shift. */ @ExpressionDescription( - usage = "base >> expr - Bitwise (signed) right shift.", + usage = "base _FUNC_ expr - Bitwise (signed) right shift.", examples = """ Examples: > SELECT shiftright(4, 1); @@ -1350,7 +1350,7 @@ case class ShiftRight(left: Expression, right: Expression) extends BitShiftOpera * @param right the number of bits to right shift. */ @ExpressionDescription( - usage = "base >>> expr - Bitwise unsigned right shift.", + usage = "base _FUNC_ expr - Bitwise unsigned right shift.", examples = """ Examples: > SELECT shiftrightunsigned(4, 1); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index f5db972a28643..f329f8346b0de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -17,13 +17,18 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedSeed} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, TreePattern} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike} +import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.random.XORShiftRandom /** @@ -33,8 +38,7 @@ import 
org.apache.spark.util.random.XORShiftRandom * * Since this expression is stateful, it cannot be a case object. */ -abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic - with ExpressionWithRandomSeed { +trait RDG extends Expression with ExpressionWithRandomSeed { /** * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize and initialize it. @@ -43,12 +47,6 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm override def stateful: Boolean = true - override protected def initializeInternal(partitionIndex: Int): Unit = { - rng = new XORShiftRandom(seed + partitionIndex) - } - - override def seedExpression: Expression = child - @transient protected lazy val seed: Long = seedExpression match { case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int] case e if e.dataType == LongType => e.eval().asInstanceOf[Long] @@ -57,6 +55,15 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm override def nullable: Boolean = false override def dataType: DataType = DoubleType +} + +abstract class NondeterministicUnaryRDG + extends RDG with UnaryLike[Expression] with Nondeterministic with ExpectsInputTypes { + override def seedExpression: Expression = child + + override protected def initializeInternal(partitionIndex: Int): Unit = { + rng = new XORShiftRandom(seed + partitionIndex) + } override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType)) } @@ -74,6 +81,7 @@ trait ExpressionWithRandomSeed extends Expression { private[catalyst] object ExpressionWithRandomSeed { def expressionToSeed(e: Expression, source: String): Option[Long] = e match { + case IntegerLiteral(seed) => Some(seed) case LongLiteral(seed) => Some(seed) case Literal(null, _) => None case _ => throw QueryCompilationErrors.invalidRandomSeedParameter(source, e) @@ -99,7 +107,7 @@ private[catalyst] object ExpressionWithRandomSeed { since = "1.5.0", group = "math_funcs") // scalastyle:on line.size.limit -case class Rand(child: Expression, hideSeed: Boolean = false) extends RDG { +case class Rand(child: Expression, hideSeed: Boolean = false) extends NondeterministicUnaryRDG { def this() = this(UnresolvedSeed, true) @@ -150,7 +158,7 @@ object Rand { since = "1.5.0", group = "math_funcs") // scalastyle:on line.size.limit -case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG { +case class Randn(child: Expression, hideSeed: Boolean = false) extends NondeterministicUnaryRDG { def this() = this(UnresolvedSeed, true) @@ -181,3 +189,236 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG { object Randn { def apply(seed: Long): Randn = Randn(Literal(seed, LongType)) } + +@ExpressionDescription( + usage = """ + _FUNC_(min, max[, seed]) - Returns a random value with independent and identically + distributed (i.i.d.) values with the specified range of numbers. The random seed is optional. + The provided numbers specifying the minimum and maximum values of the range must be constant. + If both of these numbers are integers, then the result will also be an integer. Otherwise if + one or both of these are floating-point numbers, then the result will also be a floating-point + number. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(10, 20, 0) > 0 AS result; + true + """, + since = "4.0.0", + group = "math_funcs") +case class Uniform(min: Expression, max: Expression, seedExpression: Expression) + extends RuntimeReplaceable with TernaryLike[Expression] with RDG { + def this(min: Expression, max: Expression) = this(min, max, UnresolvedSeed) + + final override lazy val deterministic: Boolean = false + override val nodePatterns: Seq[TreePattern] = + Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED) + + override val dataType: DataType = { + val first = min.dataType + val second = max.dataType + (min.dataType, max.dataType) match { + case _ if !seedExpression.resolved || seedExpression.dataType == NullType => + NullType + case (_, NullType) | (NullType, _) => NullType + case (_, LongType) | (LongType, _) + if Seq(first, second).forall(integer) => LongType + case (_, IntegerType) | (IntegerType, _) + if Seq(first, second).forall(integer) => IntegerType + case (_, ShortType) | (ShortType, _) + if Seq(first, second).forall(integer) => ShortType + case (_, DoubleType) | (DoubleType, _) => DoubleType + case (_, FloatType) | (FloatType, _) => FloatType + case _ => + throw SparkException.internalError( + s"Unexpected argument data types: ${min.dataType}, ${max.dataType}") + } + } + + private def integer(t: DataType): Boolean = t match { + case _: ShortType | _: IntegerType | _: LongType => true + case _ => false + } + + override def checkInputDataTypes(): TypeCheckResult = { + var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess + def requiredType = "integer or floating-point" + Seq((min, "min", 0), + (max, "max", 1), + (seedExpression, "seed", 2)).foreach { + case (expr: Expression, name: String, index: Int) => + if (result == TypeCheckResult.TypeCheckSuccess) { + if (!expr.foldable) { + result = DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> name, + "inputType" -> requiredType, + "inputExpr" -> toSQLExpr(expr))) + } else expr.dataType match { + case _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | + _: NullType => + case _ => + result = DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(index), + "requiredType" -> requiredType, + "inputSql" -> toSQLExpr(expr), + "inputType" -> toSQLType(expr.dataType))) + } + } + } + result + } + + override def first: Expression = min + override def second: Expression = max + override def third: Expression = seedExpression + + override def withNewSeed(newSeed: Long): Expression = + Uniform(min, max, Literal(newSeed, LongType)) + + override def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + Uniform(newFirst, newSecond, newThird) + + override def replacement: Expression = { + if (Seq(min, max, seedExpression).exists(_.dataType == NullType)) { + Literal(null) + } else { + def cast(e: Expression, to: DataType): Expression = if (e.dataType == to) e else Cast(e, to) + cast(Add( + cast(min, DoubleType), + Multiply( + Subtract( + cast(max, DoubleType), + cast(min, DoubleType)), + Rand(seed))), + dataType) + } + } +} + +@ExpressionDescription( + usage = """ + _FUNC_(length[, seed]) - Returns a string of the specified length whose characters are chosen + uniformly at random from the following pool of characters: 0-9, a-z, A-Z. The random seed is + optional. 
The string length must be a constant two-byte or four-byte integer (SMALLINT or INT, + respectively). + """, + examples = + """ + Examples: + > SELECT _FUNC_(3, 0) AS result; + ceV + """, + since = "4.0.0", + group = "string_funcs") +case class RandStr(length: Expression, override val seedExpression: Expression) + extends ExpressionWithRandomSeed with BinaryLike[Expression] with Nondeterministic { + def this(length: Expression) = this(length, UnresolvedSeed) + + override def nullable: Boolean = false + override def dataType: DataType = StringType + override def stateful: Boolean = true + override def left: Expression = length + override def right: Expression = seedExpression + + /** + * Record ID within each partition. By being transient, the Random Number Generator is + * reset every time we serialize and deserialize and initialize it. + */ + @transient protected var rng: XORShiftRandom = _ + + @transient protected lazy val seed: Long = seedExpression match { + case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int] + case e if e.dataType == LongType => e.eval().asInstanceOf[Long] + } + override protected def initializeInternal(partitionIndex: Int): Unit = { + rng = new XORShiftRandom(seed + partitionIndex) + } + + override def withNewSeed(newSeed: Long): Expression = RandStr(length, Literal(newSeed, LongType)) + override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression = + RandStr(newFirst, newSecond) + + override def checkInputDataTypes(): TypeCheckResult = { + var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess + def requiredType = "INT or SMALLINT" + Seq((length, "length", 0), + (seedExpression, "seedExpression", 1)).foreach { + case (expr: Expression, name: String, index: Int) => + if (result == TypeCheckResult.TypeCheckSuccess) { + if (!expr.foldable) { + result = DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> name, + "inputType" -> requiredType, + "inputExpr" -> toSQLExpr(expr))) + } else expr.dataType match { + case _: ShortType | _: IntegerType => + case _: LongType if index == 1 => + case _ => + result = DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(index), + "requiredType" -> requiredType, + "inputSql" -> toSQLExpr(expr), + "inputType" -> toSQLType(expr.dataType))) + } + } + } + result + } + + override def evalInternal(input: InternalRow): Any = { + val numChars = length.eval(input).asInstanceOf[Number].intValue() + val bytes = new Array[Byte](numChars) + (0 until numChars).foreach { i => + // We generate a random number between 0 and 61, inclusive. Between the 62 different choices + // we choose 0-9, a-z, or A-Z, where each category comprises 10 choices, 26 choices, or 26 + // choices, respectively (10 + 26 + 26 = 62). 
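The comment above describes the 62-character pool that both the interpreted and code-generated paths below draw from. A standalone sketch of that mapping (the helper name is made up):

def poolChar(v: Int): Char =
  if (v < 10) ('0' + v).toChar             // 0..9   -> '0'..'9'
  else if (v < 36) ('a' + (v - 10)).toChar // 10..35 -> 'a'..'z'
  else ('A' + (v - 36)).toChar             // 36..61 -> 'A'..'Z'

// 10 digits + 26 lowercase + 26 uppercase = 62 distinct characters in total.
assert((0 to 61).map(poolChar).distinct.length == 62)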
+ val num = (rng.nextInt() % 62).abs + num match { + case _ if num < 10 => + bytes.update(i, ('0' + num).toByte) + case _ if num < 36 => + bytes.update(i, ('a' + num - 10).toByte) + case _ => + bytes.update(i, ('A' + num - 36).toByte) + } + } + val result: UTF8String = UTF8String.fromBytes(bytes.toArray) + result + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val className = classOf[XORShiftRandom].getName + val rngTerm = ctx.addMutableState(className, "rng") + ctx.addPartitionInitializationStatement( + s"$rngTerm = new $className(${seed}L + partitionIndex);") + val eval = length.genCode(ctx) + ev.copy(code = + code""" + |${eval.code} + |int length = (int)(${eval.value}); + |char[] chars = new char[length]; + |for (int i = 0; i < length; i++) { + | int v = Math.abs($rngTerm.nextInt() % 62); + | if (v < 10) { + | chars[i] = (char)('0' + v); + | } else if (v < 36) { + | chars[i] = (char)('a' + (v - 10)); + | } else { + | chars[i] = (char)('A' + (v - 36)); + | } + |} + |UTF8String ${ev.value} = UTF8String.fromString(new String(chars)); + |boolean ${ev.isNull} = false; + |""".stripMargin, + isNull = FalseLiteral) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 6ccd5a451eafc..da6d786efb4e3 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.nio.{ByteBuffer, CharBuffer} import java.nio.charset.CharacterCodingException -import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols} +import java.text.{DecimalFormat, DecimalFormatSymbols} import java.util.{Base64 => JBase64, HashMap, Locale, Map => JMap} import scala.collection.mutable.ArrayBuffer @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO import org.apache.spark.sql.catalyst.util.{ArrayData, CharsetProvider, CollationFactory, CollationSupport, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeNonCSAICollation} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods @@ -609,6 +609,8 @@ case class Contains(left: Expression, right: Expression) extends StringPredicate defineCodeGen(ctx, ev, (c1, c2) => CollationSupport.Contains.genCode(c1, c2, collationId)) } + override def inputTypes : Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Contains = copy(left = newLeft, right = newRight) } @@ -650,6 +652,10 @@ case class StartsWith(left: Expression, right: Expression) extends StringPredica defineCodeGen(ctx, ev, (c1, c2) => CollationSupport.StartsWith.genCode(c1, c2, collationId)) } + + override def inputTypes : Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def 
withNewChildrenInternal( newLeft: Expression, newRight: Expression): StartsWith = copy(left = newLeft, right = newRight) } @@ -691,6 +697,10 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate defineCodeGen(ctx, ev, (c1, c2) => CollationSupport.EndsWith.genCode(c1, c2, collationId)) } + + override def inputTypes : Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): EndsWith = copy(left = newLeft, right = newRight) } @@ -919,7 +929,7 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp override def dataType: DataType = srcExpr.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) override def first: Expression = srcExpr override def second: Expression = searchExpr override def third: Expression = replaceExpr @@ -1167,7 +1177,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac override def dataType: DataType = srcExpr.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) override def first: Expression = srcExpr override def second: Expression = matchingExpr override def third: Expression = replaceExpr @@ -1394,6 +1404,9 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None) override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = CollationSupport.StringTrim.exec(srcString, trimString, collationId) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy( srcStr = newChildren.head, @@ -1501,6 +1514,9 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = CollationSupport.StringTrimLeft.exec(srcString, trimString, collationId) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimLeft = copy( @@ -1561,6 +1577,9 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = CollationSupport.StringTrimRight.exec(srcString, trimString, collationId) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimRight = copy( @@ -1595,7 +1614,7 @@ case class StringInstr(str: Expression, substr: Expression) override def right: Expression = substr override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) override def nullSafeEval(string: Any, sub: Any): Any = { CollationSupport.StringInstr. 
@@ -1643,7 +1662,7 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: override def dataType: DataType = strExpr.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType) override def first: Expression = strExpr override def second: Expression = delimExpr override def third: Expression = countExpr @@ -1701,7 +1720,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) override def nullable: Boolean = substr.nullable || str.nullable override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType) override def eval(input: InternalRow): Any = { val s = start.eval(input) @@ -1969,13 +1988,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC if (pattern == null) { null } else { - val sb = new StringBuffer() - val formatter = new java.util.Formatter(sb, Locale.US) - + val formatter = new java.util.Formatter(Locale.US) val arglist = children.tail.map(_.eval(input).asInstanceOf[AnyRef]) - formatter.format(pattern.asInstanceOf[UTF8String].toString, arglist: _*) - - UTF8String.fromString(sb.toString) + UTF8String.fromString( + formatter.format(pattern.asInstanceOf[UTF8String].toString, arglist: _*).toString) } } @@ -2006,19 +2022,16 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val form = ctx.freshName("formatter") val formatter = classOf[java.util.Formatter].getName - val sb = ctx.freshName("sb") - val stringBuffer = classOf[StringBuffer].getName ev.copy(code = code""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { - $stringBuffer $sb = new $stringBuffer(); - $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); + $formatter $form = new $formatter(${classOf[Locale].getName}.US); Object[] $argList = new Object[$numArgLists]; $argListCodes - $form.format(${pattern.value}.toString(), $argList); - ${ev.value} = UTF8String.fromString($sb.toString()); + ${ev.value} = UTF8String.fromString( + $form.format(${pattern.value}.toString(), $argList).toString()); }""") } @@ -3333,14 +3346,37 @@ case class FormatNumber(x: Expression, d: Expression) /** * Splits a string into arrays of sentences, where each sentence is an array of words. - * The 'lang' and 'country' arguments are optional, and if omitted, the default locale is used. + * The `lang` and `country` arguments are optional, their default values are all '', + * - When they are omitted: + * 1. If they are both omitted, the `Locale.ROOT - locale(language='', country='')` is used. + * The `Locale.ROOT` is regarded as the base locale of all locales, and is used as the + * language/country neutral locale for the locale sensitive operations. + * 2. If the `country` is omitted, the `locale(language, country='')` is used. + * - When they are null: + * 1. If they are both `null`, the `Locale.US - locale(language='en', country='US')` is used. + * 2. If the `language` is null and the `country` is not null, + * the `Locale.US - locale(language='en', country='US')` is used. + * 3. If the `language` is not null and the `country` is null, the `locale(language)` is used. + * 4. 
If neither is `null`, the `locale(language, country)` is used. */ @ExpressionDescription( - usage = "_FUNC_(str[, lang, country]) - Splits `str` into an array of array of words.", + usage = "_FUNC_(str[, lang[, country]]) - Splits `str` into an array of array of words.", + arguments = """ + Arguments: + * str - A STRING expression to be parsed. + * lang - An optional STRING expression with a language code from ISO 639 Alpha-2 (e.g. 'DE'), + Alpha-3, or a language subtag of up to 8 characters. + * country - An optional STRING expression with a country code from ISO 3166 alpha-2 country + code or a UN M.49 numeric-3 area code. + """, examples = """ Examples: > SELECT _FUNC_('Hi there! Good morning.'); [["Hi","there"],["Good","morning"]] + > SELECT _FUNC_('Hi there! Good morning.', 'en'); + [["Hi","there"],["Good","morning"]] + > SELECT _FUNC_('Hi there! Good morning.', 'en', 'US'); + [["Hi","there"],["Good","morning"]] """, since = "2.0.0", group = "string_funcs") @@ -3348,7 +3384,9 @@ case class Sentences( str: Expression, language: Expression = Literal(""), country: Expression = Literal("")) - extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback { + extends TernaryExpression + with ImplicitCastInputTypes + with RuntimeReplaceable { def this(str: Expression) = this(str, Literal(""), Literal("")) def this(str: Expression, language: Expression) = this(str, language, Literal("")) @@ -3362,49 +3400,18 @@ case class Sentences( override def second: Expression = language override def third: Expression = country - override def eval(input: InternalRow): Any = { - val string = str.eval(input) - if (string == null) { - null - } else { - val languageStr = language.eval(input).asInstanceOf[UTF8String] - val countryStr = country.eval(input).asInstanceOf[UTF8String] - val locale = if (languageStr != null && countryStr != null) { - new Locale(languageStr.toString, countryStr.toString) - } else { - Locale.US - } - getSentences(string.asInstanceOf[UTF8String].toString, locale) - } - } - - private def getSentences(sentences: String, locale: Locale) = { - val bi = BreakIterator.getSentenceInstance(locale) - bi.setText(sentences) - var idx = 0 - val result = new ArrayBuffer[GenericArrayData] - while (bi.next != BreakIterator.DONE) { - val sentence = sentences.substring(idx, bi.current) - idx = bi.current - - val wi = BreakIterator.getWordInstance(locale) - var widx = 0 - wi.setText(sentence) - val words = new ArrayBuffer[UTF8String] - while (wi.next != BreakIterator.DONE) { - val word = sentence.substring(widx, wi.current) - widx = wi.current - if (Character.isLetterOrDigit(word.charAt(0))) words += UTF8String.fromString(word) - } - result += new GenericArrayData(words) - } - new GenericArrayData(result) - } + override def replacement: Expression = + StaticInvoke( + classOf[ExpressionImplUtils], + dataType, + "getSentences", + Seq(str, language, country), + inputTypes, + propagateNull = false) override protected def withNewChildrenInternal( newFirst: Expression, newSecond: Expression, newThird: Expression): Sentences = copy(str = newFirst, language = newSecond, country = newThird) - } /** @@ -3475,7 +3482,7 @@ case class SplitPart ( false) override def nodeName: String = "split_part" override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType) def children: Seq[Expression] = Seq(str, delimiter, partNum) protected def withNewChildrenInternal(newChildren: 
IndexedSeq[Expression]): Expression = { copy(str = newChildren.apply(0), delimiter = newChildren.apply(1), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala index a34a2c740876e..ad9610ea0c78a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.{MapType, NullType, StringType} +import org.apache.spark.sql.types.{BinaryType, MapType, NullType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -88,6 +88,27 @@ case class FromProtobuf( messageName: Expression, descFilePath: Expression, options: Expression) extends QuaternaryExpression with RuntimeReplaceable { + + def this(data: Expression, messageName: Expression, descFilePathOrOptions: Expression) = { + this( + data, + messageName, + descFilePathOrOptions match { + case lit: Literal + if lit.dataType == StringType || lit.dataType == BinaryType => descFilePathOrOptions + case _ => Literal(null) + }, + descFilePathOrOptions.dataType match { + case _: MapType => descFilePathOrOptions + case _ => Literal(null) + } + ) + } + + def this(data: Expression, messageName: Expression) = { + this(data, messageName, Literal(null), Literal(null)) + } + override def first: Expression = data override def second: Expression = messageName override def third: Expression = descFilePath @@ -110,11 +131,11 @@ case class FromProtobuf( "representing the Protobuf message name")) } val descFilePathCheck = descFilePath.dataType match { - case _: StringType if descFilePath.foldable => None + case _: StringType | BinaryType | NullType if descFilePath.foldable => None case _ => Some(TypeCheckResult.TypeCheckFailure( "The third argument of the FROM_PROTOBUF SQL function must be a constant string " + - "representing the Protobuf descriptor file path")) + "or binary data representing the Protobuf descriptor file path")) } val optionsCheck = options.dataType match { case MapType(StringType, StringType, _) | @@ -141,7 +162,10 @@ case class FromProtobuf( s.toString } val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match { + case s: UTF8String if s.toString.isEmpty => None case s: UTF8String => Some(ProtobufHelper.readDescriptorFileContent(s.toString)) + case bytes: Array[Byte] if bytes.isEmpty => None + case bytes: Array[Byte] => Some(bytes) case null => None } val optionsValue: Map[String, String] = options.eval() match { @@ -201,6 +225,27 @@ case class ToProtobuf( messageName: Expression, descFilePath: Expression, options: Expression) extends QuaternaryExpression with RuntimeReplaceable { + + def this(data: Expression, messageName: Expression, descFilePathOrOptions: Expression) = { + this( + data, + messageName, + descFilePathOrOptions match { + case lit: Literal + if lit.dataType == StringType || lit.dataType == BinaryType => descFilePathOrOptions + case _ => Literal(null) + }, + descFilePathOrOptions.dataType match { + case _: MapType => descFilePathOrOptions + case _ => Literal(null) 
+ } + ) + } + + def this(data: Expression, messageName: Expression) = { + this(data, messageName, Literal(null), Literal(null)) + } + override def first: Expression = data override def second: Expression = messageName override def third: Expression = descFilePath @@ -223,11 +268,11 @@ case class ToProtobuf( "representing the Protobuf message name")) } val descFilePathCheck = descFilePath.dataType match { - case _: StringType if descFilePath.foldable => None + case _: StringType | BinaryType | NullType if descFilePath.foldable => None case _ => Some(TypeCheckResult.TypeCheckFailure( "The third argument of the TO_PROTOBUF SQL function must be a constant string " + - "representing the Protobuf descriptor file path")) + "or binary data representing the Protobuf descriptor file path")) } val optionsCheck = options.dataType match { case MapType(StringType, StringType, _) | @@ -256,6 +301,7 @@ case class ToProtobuf( } val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match { case s: UTF8String => Some(ProtobufHelper.readDescriptorFileContent(s.toString)) + case bytes: Array[Byte] => Some(bytes) case null => None } val optionsValue: Map[String, String] = options.eval() match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala index bbf554d384b12..487985b4770ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala @@ -126,7 +126,7 @@ object VariantExpressionEvalUtils { buildVariant(builder, element, elementType) } builder.finishWritingArray(start, offsets) - case MapType(StringType, valueType, _) => + case MapType(_: StringType, valueType, _) => val data = input.asInstanceOf[MapData] val keys = data.keyArray() val values = data.valueArray() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index bd956fa5c00e1..2c8ca1e8bb2bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.json.JsonInferSchema import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, VARIANT_GET} import org.apache.spark.sql.catalyst.trees.UnaryLike -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, QuotingUtils} import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf @@ -117,6 +117,73 @@ case class IsVariantNull(child: Expression) extends UnaryExpression copy(child = newChild) } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr) - Convert a nested input (array/map/struct) into a variant where maps and structs are converted to variant objects which are unordered unlike SQL structs. 
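The additional `this(...)` constructors above let the SQL `from_protobuf`/`to_protobuf` functions take either a descriptor file path (string), raw descriptor bytes (binary), or an options map as the third argument. A hedged usage sketch, assuming an active SparkSession in scope as `spark`, a table `events` with a BINARY column `payload`, and a compiled Protobuf message; every name here is illustrative:

// Third argument given as a descriptor file path.
spark.sql("SELECT from_protobuf(payload, 'Event', '/tmp/event.desc') AS parsed FROM events")

// Third argument given as an options map instead; the message is then resolved by its
// Java class name from classes available on the classpath.
spark.sql(
  "SELECT from_protobuf(payload, 'org.example.Event', " +
  "map('recursive.fields.max.depth', '2')) AS parsed FROM events")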
Input maps can only have string keys.", + examples = """ + Examples: + > SELECT _FUNC_(named_struct('a', 1, 'b', 2)); + {"a":1,"b":2} + > SELECT _FUNC_(array(1, 2, 3)); + [1,2,3] + > SELECT _FUNC_(array(named_struct('a', 1))); + [{"a":1}] + > SELECT _FUNC_(array(map("a", 2))); + [{"a":2}] + """, + since = "4.0.0", + group = "variant_funcs") +// scalastyle:on line.size.limit +case class ToVariantObject(child: Expression) + extends UnaryExpression + with NullIntolerant + with QueryErrorsBase { + + override val dataType: DataType = VariantType + + // Only accept nested types at the root but any types can be nested inside. + override def checkInputDataTypes(): TypeCheckResult = { + val checkResult: Boolean = child.dataType match { + case _: StructType | _: ArrayType | _: MapType => + VariantGet.checkDataType(child.dataType, allowStructsAndMaps = true) + case _ => false + } + if (!checkResult) { + DataTypeMismatch( + errorSubClass = "CAST_WITHOUT_SUGGESTION", + messageParameters = + Map("srcType" -> toSQLType(child.dataType), "targetType" -> toSQLType(VariantType))) + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def prettyName: String = "to_variant_object" + + override protected def withNewChildInternal(newChild: Expression): ToVariantObject = + copy(child = newChild) + + protected override def nullSafeEval(input: Any): Any = + VariantExpressionEvalUtils.castToVariant(input, child.dataType) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childCode = child.genCode(ctx) + val cls = variant.VariantExpressionEvalUtils.getClass.getName.stripSuffix("$") + val fromArg = ctx.addReferenceObj("from", child.dataType) + val javaType = JavaCode.javaType(VariantType) + val code = + code""" + ${childCode.code} + boolean ${ev.isNull} = ${childCode.isNull}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(VariantType)}; + if (!${childCode.isNull}) { + ${ev.value} = $cls.castToVariant(${childCode.value}, $fromArg); + } + """ + ev.copy(code = code) + } +} + object VariantPathParser extends RegexParsers { // A path segment in the `VariantGet` expression represents either an object key access or an // array index access. @@ -260,13 +327,16 @@ case object VariantGet { * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset * of them. For nested types, we reject map types with a non-string key type. 
*/ - def checkDataType(dataType: DataType): Boolean = dataType match { + def checkDataType(dataType: DataType, allowStructsAndMaps: Boolean = true): Boolean = + dataType match { case _: NumericType | BooleanType | _: StringType | BinaryType | _: DatetimeType | VariantType | _: DayTimeIntervalType | _: YearMonthIntervalType => true - case ArrayType(elementType, _) => checkDataType(elementType) - case MapType(_: StringType, valueType, _) => checkDataType(valueType) - case StructType(fields) => fields.forall(f => checkDataType(f.dataType)) + case ArrayType(elementType, _) => checkDataType(elementType, allowStructsAndMaps) + case MapType(_: StringType, valueType, _) if allowStructsAndMaps => + checkDataType(valueType, allowStructsAndMaps) + case StructType(fields) if allowStructsAndMaps => + fields.forall(f => checkDataType(f.dataType, allowStructsAndMaps)) case _ => false } @@ -635,7 +705,7 @@ object VariantExplode { > SELECT _FUNC_(parse_json('null')); VOID > SELECT _FUNC_(parse_json('[{"b":true,"a":0}]')); - ARRAY> + ARRAY> """, since = "4.0.0", group = "variant_funcs" @@ -666,7 +736,24 @@ object SchemaOfVariant { /** The actual implementation of the `SchemaOfVariant` expression. */ def schemaOfVariant(input: VariantVal): UTF8String = { val v = new Variant(input.getValue, input.getMetadata) - UTF8String.fromString(schemaOf(v).sql) + UTF8String.fromString(printSchema(schemaOf(v))) + } + + /** + * Similar to `dataType.sql`. The only difference is that `StructType` is shown as + * `OBJECT<...>` rather than `STRUCT<...>`. + * SchemaOfVariant expressions use the Struct DataType to denote the Object type in the variant + * spec. However, the Object type is not equivalent to the struct type as an Object represents an + * unordered bag of key-value pairs while the Struct type is ordered. 
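A hedged illustration of the printing change described above: `schema_of_variant` (and `schema_of_variant_agg`) now render variant objects as OBJECT<...> instead of STRUCT<...>, while arrays keep the ARRAY<...> rendering. The JSON literals are arbitrary examples, and an active local SparkSession is assumed.

import org.apache.spark.sql.SparkSession

object SchemaOfVariantSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    // Expected to print something like OBJECT<a: BIGINT, b: STRING> rather than a STRUCT type.
    spark.sql("""SELECT schema_of_variant(parse_json('{"a": 1, "b": "hello"}'))""")
      .show(truncate = false)
    // Nested objects inside arrays are also rendered as OBJECT<...>.
    spark.sql("""SELECT schema_of_variant(parse_json('[{"b": true, "a": 0}]'))""")
      .show(truncate = false)
    spark.stop()
  }
}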
+ */ + def printSchema(dataType: DataType): String = dataType match { + case StructType(fields) => + def printField(f: StructField): String = + s"${QuotingUtils.quoteIfNeeded(f.name)}: ${printSchema(f.dataType)}" + + s"OBJECT<${fields.map(printField).mkString(", ")}>" + case ArrayType(elementType, _) => s"ARRAY<${printSchema(elementType)}>" + case _ => dataType.sql } /** @@ -731,7 +818,7 @@ object SchemaOfVariant { > SELECT _FUNC_(parse_json(j)) FROM VALUES ('1'), ('2'), ('3') AS tab(j); BIGINT > SELECT _FUNC_(parse_json(j)) FROM VALUES ('{"a": 1}'), ('{"b": true}'), ('{"c": 1.23}') AS tab(j); - STRUCT + OBJECT """, since = "4.0.0", group = "variant_funcs") @@ -767,7 +854,8 @@ case class SchemaOfVariantAgg( override def merge(buffer: DataType, input: DataType): DataType = SchemaOfVariant.mergeSchema(buffer, input) - override def eval(buffer: DataType): Any = UTF8String.fromString(buffer.sql) + override def eval(buffer: DataType): Any = + UTF8String.fromString(SchemaOfVariant.printSchema(buffer)) override def serialize(buffer: DataType): Array[Byte] = buffer.json.getBytes("UTF-8") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index 5b06741a2f54e..31e65cf0abc95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -255,7 +255,7 @@ case class XPathList(xml: Expression, path: Expression) extends XPathExtract { override def nullSafeEval(xml: Any, path: Any): Any = { val nodeList = xpathUtil.evalNodeList(xml.asInstanceOf[UTF8String].toString, pathString) if (nodeList ne null) { - val ret = new Array[UTF8String](nodeList.getLength) + val ret = new Array[AnyRef](nodeList.getLength) var i = 0 while (i < nodeList.getLength) { ret(i) = UTF8String.fromString(nodeList.item(i).getNodeValue) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 26de4cc7ad1c8..13129d44fe0c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -215,7 +215,7 @@ class JacksonParser( ) } - Some(InternalRow(new GenericArrayData(res.toArray))) + Some(InternalRow(new GenericArrayData(res.toArray[Any]))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala index 424f4b96271d3..6c0d7189862d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala @@ -1064,7 +1064,7 @@ object DecorrelateInnerQuery extends PredicateHelper { // Project, they could get added at the beginning or the end of the output columns // depending on the child plan. // The inner expressions for the domain are the values of newOuterReferenceMap. 
- val domainProjections = collectedChildOuterReferences.map(newOuterReferenceMap(_)) + val domainProjections = newOuterReferences.map(newOuterReferenceMap(_)) val newChild = Project(child.output ++ domainProjections, decorrelatedChild) (newChild, newJoinCond, newOuterReferenceMap) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a2a26924885c0..8e14537c6a5b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -303,7 +303,8 @@ abstract class Optimizer(catalogManager: CatalogManager) ReplaceCurrentLike(catalogManager), SpecialDatetimeValues, RewriteAsOfJoin, - EvalInlineTables + EvalInlineTables, + ReplaceTranspose ) override def apply(plan: LogicalPlan): LogicalPlan = { @@ -1722,15 +1723,18 @@ object EliminateSorts extends Rule[LogicalPlan] { * 3) by eliminating the always-true conditions given the constraints on the child's output. */ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { + private def shouldApply(child: LogicalPlan): Boolean = + SQLConf.get.getConf(SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN) || !child.isStreaming + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsPattern(FILTER), ruleId) { // If the filter condition always evaluate to true, remove the filter. case Filter(Literal(true, BooleanType), child) => child // If the filter condition always evaluate to null or false, // replace the input with an empty relation. - case Filter(Literal(null, _), child) => + case Filter(Literal(null, _), child) if shouldApply(child) => LocalRelation(child.output, data = Seq.empty, isStreaming = child.isStreaming) - case Filter(Literal(false, BooleanType), child) => + case Filter(Literal(false, BooleanType), child) if shouldApply(child) => LocalRelation(child.output, data = Seq.empty, isStreaming = child.isStreaming) // If any deterministic condition is guaranteed to be true given the constraints on the child's // output, remove the condition diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 801bd2693af42..5aef82b64ed32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -400,13 +400,14 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { (distinctAggOperatorMap.flatMap(_._2) ++ regularAggOperatorMap.map(e => (e._1, e._3))).toMap + val groupByMapNonFoldable = groupByMap.filter(!_._1.foldable) val patchedAggExpressions = a.aggregateExpressions.map { e => e.transformDown { case e: Expression => // The same GROUP BY clauses can have different forms (different names for instance) in // the groupBy and aggregate expressions of an aggregate. This makes a map lookup // tricky. So we do a linear search for a semantically equal group by expression. 
- groupByMap + groupByMapNonFoldable .find(ge => e.semanticEquals(ge._1)) .map(_._2) .getOrElse(transformations.getOrElse(e, e)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index d0ee9f2d110d5..3cdde622d51f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -79,7 +79,7 @@ object ConstantFolding extends Rule[LogicalPlan] { // Fold expressions that are foldable. case e if e.foldable => try { - Literal.create(e.eval(EmptyRow), e.dataType) + Literal.create(e.freshCopyIfContainsStatefulExpression().eval(EmptyRow), e.dataType) } catch { case NonFatal(_) if isConditionalBranch => // When doing constant folding inside conditional expressions, we should not fail diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 55f222d880844..a524acc19aea8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -184,3 +184,10 @@ object SpecialDatetimeValues extends Rule[LogicalPlan] { } } } + +object ReplaceTranspose extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.transform { + case t @ Transpose(output, data) => + LocalRelation(output, data) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b0922542c5629..674005caaf1b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer, Set} import scala.jdk.CollectionConverters._ import scala.util.{Left, Right} -import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.antlr.v4.runtime.{ParserRuleContext, RuleContext, Token} import org.antlr.v4.runtime.misc.Interval import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} @@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone} -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog, TableWritePrivilege} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, IdentityColumnSpec, SupportsNamespaces, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors, QueryParsingErrors, 
SqlScriptingErrors} @@ -173,14 +173,10 @@ class AstBuilder extends DataTypeAstBuilder case Some(c: CreateVariable) => if (allowVarDeclare) { throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning( - c.origin, - toSQLId(c.name.asInstanceOf[UnresolvedIdentifier].nameParts), - c.origin.line.get.toString) + c.origin, c.name.asInstanceOf[UnresolvedIdentifier].nameParts) } else { throw SqlScriptingErrors.variableDeclarationNotAllowedInScope( - c.origin, - toSQLId(c.name.asInstanceOf[UnresolvedIdentifier].nameParts), - c.origin.line.get.toString) + c.origin, c.name.asInstanceOf[UnresolvedIdentifier].nameParts) } case _ => } @@ -200,7 +196,9 @@ class AstBuilder extends DataTypeAstBuilder el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) => withOrigin(bl) { throw SqlScriptingErrors.labelsMismatch( - CurrentOrigin.get, bl.multipartIdentifier().getText, el.multipartIdentifier().getText) + CurrentOrigin.get, + bl.multipartIdentifier().getText, + el.multipartIdentifier().getText) } case (None, Some(el: EndLabelContext)) => withOrigin(el) { @@ -261,6 +259,120 @@ class AstBuilder extends DataTypeAstBuilder WhileStatement(condition, body, Some(labelText)) } + override def visitSearchedCaseStatement(ctx: SearchedCaseStatementContext): CaseStatement = { + val conditions = ctx.conditions.asScala.toList.map(boolExpr => withOrigin(boolExpr) { + SingleStatement( + Project( + Seq(Alias(expression(boolExpr), "condition")()), + OneRowRelation())) + }) + val conditionalBodies = + ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body)) + + if (conditions.length != conditionalBodies.length) { + throw SparkException.internalError( + s"Mismatched number of conditions ${conditions.length} and condition bodies" + + s" ${conditionalBodies.length} in case statement") + } + + CaseStatement( + conditions = conditions, + conditionalBodies = conditionalBodies, + elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body))) + } + + override def visitSimpleCaseStatement(ctx: SimpleCaseStatementContext): CaseStatement = { + // uses EqualTo to compare the case variable(the main case expression) + // to the WHEN clause expressions + val conditions = ctx.conditionExpressions.asScala.toList.map(expr => withOrigin(expr) { + SingleStatement( + Project( + Seq(Alias(EqualTo(expression(ctx.caseVariable), expression(expr)), "condition")()), + OneRowRelation())) + }) + val conditionalBodies = + ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body)) + + if (conditions.length != conditionalBodies.length) { + throw SparkException.internalError( + s"Mismatched number of conditions ${conditions.length} and condition bodies" + + s" ${conditionalBodies.length} in case statement") + } + + CaseStatement( + conditions = conditions, + conditionalBodies = conditionalBodies, + elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body))) + } + + override def visitRepeatStatement(ctx: RepeatStatementContext): RepeatStatement = { + val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel())) + val boolExpr = ctx.booleanExpression() + + val condition = withOrigin(boolExpr) { + SingleStatement( + Project( + Seq(Alias(expression(boolExpr), "condition")()), + OneRowRelation()))} + val body = visitCompoundBody(ctx.compoundBody()) + + RepeatStatement(condition, body, Some(labelText)) + } + + private def leaveOrIterateContextHasLabel( + ctx: RuleContext, label: String, isIterate: Boolean): Boolean = { + ctx match { + case c: BeginEndCompoundBlockContext + if 
Option(c.beginLabel()).isDefined && + c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) => + if (isIterate) { + throw SqlScriptingErrors.invalidIterateLabelUsageForCompound(CurrentOrigin.get, label) + } + true + case c: WhileStatementContext + if Option(c.beginLabel()).isDefined && + c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + => true + case c: RepeatStatementContext + if Option(c.beginLabel()).isDefined && + c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + => true + case _ => false + } + } + + override def visitLeaveStatement(ctx: LeaveStatementContext): LeaveStatement = + withOrigin(ctx) { + val labelText = ctx.multipartIdentifier().getText.toLowerCase(Locale.ROOT) + var parentCtx = ctx.parent + + while (Option(parentCtx).isDefined) { + if (leaveOrIterateContextHasLabel(parentCtx, labelText, isIterate = false)) { + return LeaveStatement(labelText) + } + parentCtx = parentCtx.parent + } + + throw SqlScriptingErrors.labelDoesNotExist( + CurrentOrigin.get, labelText, "LEAVE") + } + + override def visitIterateStatement(ctx: IterateStatementContext): IterateStatement = + withOrigin(ctx) { + val labelText = ctx.multipartIdentifier().getText.toLowerCase(Locale.ROOT) + var parentCtx = ctx.parent + + while (Option(parentCtx).isDefined) { + if (leaveOrIterateContextHasLabel(parentCtx, labelText, isIterate = true)) { + return IterateStatement(labelText) + } + parentCtx = parentCtx.parent + } + + throw SqlScriptingErrors.labelDoesNotExist( + CurrentOrigin.get, labelText, "ITERATE") + } + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { Option(ctx.statement().asInstanceOf[ParserRuleContext]) .orElse(Option(ctx.setResetStatement().asInstanceOf[ParserRuleContext])) @@ -357,7 +469,8 @@ class AstBuilder extends DataTypeAstBuilder ctx.aggregationClause, ctx.havingClause, ctx.windowClause, - plan + plan, + isPipeOperatorSelect = false ) } } @@ -945,7 +1058,8 @@ class AstBuilder extends DataTypeAstBuilder ctx.aggregationClause, ctx.havingClause, ctx.windowClause, - from + from, + isPipeOperatorSelect = false ) } @@ -1032,7 +1146,8 @@ class AstBuilder extends DataTypeAstBuilder aggregationClause, havingClause, windowClause, - isDistinct = false) + isDistinct = false, + isPipeOperatorSelect = false) ScriptTransformation( string(visitStringLit(transformClause.script)), @@ -1053,6 +1168,8 @@ class AstBuilder extends DataTypeAstBuilder * Add a regular (SELECT) query specification to a logical plan. The query specification * is the core of the logical plan, this is where sourcing (FROM clause), projection (SELECT), * aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place. + * If 'isPipeOperatorSelect' is true, wraps each projected expression with a [[PipeSelect]] + * expression for future validation of the expressions during analysis. * * Note that query hints are ignored (both by the parser and the builder). 
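A hedged sketch of the pipe-operator SELECT path described above: each projection in a `|> SELECT` clause is wrapped in a PipeSelect expression for later validation. The syntax is gated behind SQLConf.OPERATOR_PIPE_SYNTAX_ENABLED (checked in visitOperatorPipeRightSide later in this file), and the table and column names (`events`, `id`, `name`) are illustrative assumptions.

import org.apache.spark.sql.SparkSession

object PipeOperatorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    // Assumes the operator-pipe syntax flag is enabled in this session.
    // Only |> SELECT and |> WHERE are handled by this part of the patch.
    spark.sql("FROM events |> SELECT id, upper(name) AS name_upper").show()
    spark.sql("FROM events |> WHERE id > 0").show()
    spark.stop()
  }
}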
*/ @@ -1064,7 +1181,8 @@ class AstBuilder extends DataTypeAstBuilder aggregationClause: AggregationClauseContext, havingClause: HavingClauseContext, windowClause: WindowClauseContext, - relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + relation: LogicalPlan, + isPipeOperatorSelect: Boolean): LogicalPlan = withOrigin(ctx) { val isDistinct = selectClause.setQuantifier() != null && selectClause.setQuantifier().DISTINCT() != null @@ -1076,7 +1194,8 @@ class AstBuilder extends DataTypeAstBuilder aggregationClause, havingClause, windowClause, - isDistinct) + isDistinct, + isPipeOperatorSelect) // Hint selectClause.hints.asScala.foldRight(plan)(withHints) @@ -1090,7 +1209,8 @@ class AstBuilder extends DataTypeAstBuilder aggregationClause: AggregationClauseContext, havingClause: HavingClauseContext, windowClause: WindowClauseContext, - isDistinct: Boolean): LogicalPlan = { + isDistinct: Boolean, + isPipeOperatorSelect: Boolean): LogicalPlan = { // Add lateral views. val withLateralView = lateralView.asScala.foldLeft(relation)(withGenerate) @@ -1104,7 +1224,20 @@ class AstBuilder extends DataTypeAstBuilder } def createProject() = if (namedExpressions.nonEmpty) { - Project(namedExpressions, withFilter) + val newProjectList: Seq[NamedExpression] = if (isPipeOperatorSelect) { + // If this is a pipe operator |> SELECT clause, add a [[PipeSelect]] expression wrapping + // each alias in the project list, so the analyzer can check invariants later. + namedExpressions.map { + case a: Alias => + a.withNewChildren(Seq(PipeSelect(a.child))) + .asInstanceOf[NamedExpression] + case other => + other + } + } else { + namedExpressions + } + Project(newProjectList, withFilter) } else { withFilter } @@ -3398,6 +3531,14 @@ class AstBuilder extends DataTypeAstBuilder throw QueryParsingErrors.fromToIntervalUnsupportedError(from, to, ctx) } } catch { + // Keep error class of SparkIllegalArgumentExceptions and enrich it with query context + case se: SparkIllegalArgumentException => + val pe = new ParseException( + errorClass = se.getErrorClass, + messageParameters = se.getMessageParameters.asScala.toMap, + ctx) + pe.setStackTrace(se.getStackTrace) + throw pe // Handle Exceptions thrown by CalendarInterval case e: IllegalArgumentException => val pe = new ParseException( @@ -3478,13 +3619,19 @@ class AstBuilder extends DataTypeAstBuilder } } + val dataType = typedVisit[DataType](ctx.dataType) ColumnDefinition( name = name, - dataType = typedVisit[DataType](ctx.dataType), + dataType = dataType, nullable = nullable, comment = commentSpec.map(visitCommentSpec), defaultValue = defaultExpression.map(visitDefaultExpression), - generationExpression = generationExpression.map(visitGenerationExpression) + generationExpression = generationExpression.collect { + case ctx: GeneratedColumnContext => visitGeneratedColumn(ctx) + }, + identityColumnSpec = generationExpression.collect { + case ctx: IdentityColumnContext => visitIdentityColumn(ctx, dataType) + } ) } @@ -3540,11 +3687,63 @@ class AstBuilder extends DataTypeAstBuilder /** * Create a generation expression string. */ - override def visitGenerationExpression(ctx: GenerationExpressionContext): String = + override def visitGeneratedColumn(ctx: GeneratedColumnContext): String = withOrigin(ctx) { getDefaultExpression(ctx.expression(), "GENERATED").originalSQL } + /** + * Parse and verify IDENTITY column definition. + * + * @param ctx The parser context. + * @param dataType The data type of column defined as IDENTITY column. Used for verification. 
+ * @return Tuple containing start, step and allowExplicitInsert. + */ + protected def visitIdentityColumn( + ctx: IdentityColumnContext, + dataType: DataType): IdentityColumnSpec = { + if (dataType != LongType && dataType != IntegerType) { + throw QueryParsingErrors.identityColumnUnsupportedDataType(ctx, dataType.toString) + } + // We support two flavors of syntax: + // (1) GENERATED ALWAYS AS IDENTITY (...) + // (2) GENERATED BY DEFAULT AS IDENTITY (...) + // (1) forbids explicit inserts, while (2) allows. + val allowExplicitInsert = ctx.BY() != null && ctx.DEFAULT() != null + val (start, step) = visitIdentityColSpec(ctx.identityColSpec()) + + new IdentityColumnSpec(start, step, allowExplicitInsert) + } + + override def visitIdentityColSpec(ctx: IdentityColSpecContext): (Long, Long) = { + val defaultStart = 1 + val defaultStep = 1 + if (ctx == null) { + return (defaultStart, defaultStep) + } + var (start, step): (Option[Long], Option[Long]) = (None, None) + ctx.sequenceGeneratorOption().asScala.foreach { option => + if (option.start != null) { + if (start.isDefined) { + throw QueryParsingErrors.identityColumnDuplicatedSequenceGeneratorOption(ctx, "START") + } + start = Some(option.start.getText.toLong) + } else if (option.step != null) { + if (step.isDefined) { + throw QueryParsingErrors.identityColumnDuplicatedSequenceGeneratorOption(ctx, "STEP") + } + step = Some(option.step.getText.toLong) + if (step.get == 0L) { + throw QueryParsingErrors.identityColumnIllegalStep(ctx) + } + } else { + throw SparkException + .internalError(s"Invalid identity column sequence generator option: ${option.getText}") + } + } + (start.getOrElse(defaultStart), step.getOrElse(defaultStep)) + } + /** * Create an optional comment string. */ @@ -5031,6 +5230,13 @@ class AstBuilder extends DataTypeAstBuilder val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) val isLazy = ctx.LAZY != null if (query.isDefined) { + // Disallow parameter markers in the query of the cache. + // We need this limitation because we store the original query text, pre substitution. + // To lift this we would need to reconstitute the query with parameter markers replaced with + // the values given at CACHE TABLE time, or we would need to store the parameter values + // alongside the text. + // The same rule can be found in CREATE VIEW builder. + checkInvalidParameter(query.get, "the query of CACHE TABLE") CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options) } else { CacheTable( @@ -5491,6 +5697,28 @@ class AstBuilder extends DataTypeAstBuilder ctx.EXISTS != null) } + /** + * Creates a plan for invoking a procedure. + * + * For example: + * {{{ + * CALL multi_part_name(v1, v2, ...); + * CALL multi_part_name(v1, param2 => v2, ...); + * CALL multi_part_name(param1 => v1, param2 => v2, ...); + * }}} + */ + override def visitCall(ctx: CallContext): LogicalPlan = withOrigin(ctx) { + val procedure = withIdentClause(ctx.identifierReference, UnresolvedProcedure) + val args = ctx.functionArgument.asScala.map { + case expr if expr.namedArgumentExpression != null => + val namedExpr = expr.namedArgumentExpression + NamedArgumentExpression(namedExpr.key.getText, expression(namedExpr.value)) + case expr => + expression(expr) + }.toSeq + Call(procedure, args) + } + /** * Create a TimestampAdd expression. 
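A hedged sketch of the column syntax handled by visitIdentityColumn and visitIdentityColSpec above, assuming the usual START WITH / INCREMENT BY spelling of the sequence-generator options. The catalog, table, and column names are illustrative assumptions, and the target catalog must advertise SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS for the CREATE to succeed.

import org.apache.spark.sql.SparkSession

object IdentityColumnSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    // GENERATED ALWAYS forbids explicit inserts into `id`; GENERATED BY DEFAULT would allow them.
    // START WITH / INCREMENT BY map to the `start` and `step` options parsed above; each may
    // appear at most once, the step must be non-zero, and the column type must be INT or BIGINT.
    // `some_catalog.db.orders` is an assumed identifier, not taken from this patch.
    spark.sql(
      """CREATE TABLE some_catalog.db.orders (
        |  id BIGINT GENERATED ALWAYS AS IDENTITY (START WITH 1 INCREMENT BY 1),
        |  item STRING
        |)""".stripMargin)
    spark.stop()
  }
}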
*/ @@ -5627,4 +5855,59 @@ class AstBuilder extends DataTypeAstBuilder withOrigin(ctx) { visitSetVariableImpl(ctx.query(), ctx.multipartIdentifierList(), ctx.assignmentList()) } + + override def visitOperatorPipeStatement(ctx: OperatorPipeStatementContext): LogicalPlan = { + visitOperatorPipeRightSide(ctx.operatorPipeRightSide(), plan(ctx.left)) + } + + private def visitOperatorPipeRightSide( + ctx: OperatorPipeRightSideContext, left: LogicalPlan): LogicalPlan = { + if (!SQLConf.get.getConf(SQLConf.OPERATOR_PIPE_SYNTAX_ENABLED)) { + operationNotAllowed("Operator pipe SQL syntax using |>", ctx) + } + Option(ctx.selectClause).map { c => + withSelectQuerySpecification( + ctx = ctx, + selectClause = c, + lateralView = new java.util.ArrayList[LateralViewContext](), + whereClause = null, + aggregationClause = null, + havingClause = null, + windowClause = null, + relation = left, + isPipeOperatorSelect = true) + }.getOrElse(Option(ctx.whereClause).map { c => + // Add a table subquery boundary between the new filter and the input plan if one does not + // already exist. This helps the analyzer behave as if we had added the WHERE clause after a + // table subquery containing the input plan. + val withSubqueryAlias = left match { + case s: SubqueryAlias => + s + case u: UnresolvedRelation => + u + case _ => + SubqueryAlias(SubqueryAlias.generateSubqueryName(), left) + } + withWhereClause(c, withSubqueryAlias) + }.get) + } + + /** + * Check plan for any parameters. + * If it finds any throws UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT. + * This method is used to ban parameters in some contexts. + */ + protected def checkInvalidParameter(plan: LogicalPlan, statement: String): Unit = { + plan.foreach { p => + p.expressions.foreach { expr => + if (expr.containsPattern(PARAMETER)) { + throw QueryParsingErrors.parameterMarkerNotAllowed(statement, p.origin) + } + } + } + plan.children.foreach(p => checkInvalidParameter(p, statement)) + plan.innerChildren.collect { + case child: LogicalPlan => checkInvalidParameter(child, statement) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index 4a5259f09a8a3..ed40a5fd734b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -89,3 +89,52 @@ case class WhileStatement( condition: SingleStatement, body: CompoundBody, label: Option[String]) extends CompoundPlanStatement + +/** + * Logical operator for REPEAT statement. + * @param condition Any expression evaluating to a Boolean. + * Body is executed as long as the condition evaluates to false + * @param body Compound body is a collection of statements that are executed once no matter what, + * and then as long as condition is false. + * @param label An optional label for the loop which is unique amongst all labels for statements + * within which the LOOP statement is contained. + * If an end label is specified it must match the beginning label. + * The label can be used to LEAVE or ITERATE the loop. + */ +case class RepeatStatement( + condition: SingleStatement, + body: CompoundBody, + label: Option[String]) extends CompoundPlanStatement + + +/** + * Logical operator for LEAVE statement. + * The statement can be used both for compounds or any kind of loops. 
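A hedged, shape-of-the-syntax sketch of the scripting statements introduced in this patch (REPEAT ... UNTIL ... END REPEAT, LEAVE, ITERATE, and the searched CASE statement). It is kept as a plain string rather than executed, since running scripts depends on the scripting interpreter and its enabling flag, which are outside this hunk; the variable name and label are arbitrary, and the exact DECLARE/SET spellings are assumptions.

object SqlScriptingSketch {
  // Shape of a compound body exercising REPEAT, ITERATE, LEAVE, and searched CASE.
  val script: String =
    """BEGIN
      |  DECLARE i INT DEFAULT 0;
      |  loop_label: REPEAT
      |    SET i = i + 1;
      |    CASE
      |      WHEN i = 2 THEN ITERATE loop_label;
      |      WHEN i > 4 THEN LEAVE loop_label;
      |      ELSE SELECT i;
      |    END CASE;
      |  UNTIL i > 10
      |  END REPEAT loop_label;
      |END""".stripMargin
}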
+ * When used, the corresponding body/loop execution is skipped and the execution continues + * with the next statement after the body/loop. + * @param label Label of the compound or loop to leave. + */ +case class LeaveStatement(label: String) extends CompoundPlanStatement + +/** + * Logical operator for ITERATE statement. + * The statement can be used only for loops. + * When used, the rest of the loop is skipped and the loop execution continues + * with the next iteration. + * @param label Label of the loop to iterate. + */ +case class IterateStatement(label: String) extends CompoundPlanStatement + +/** + * Logical operator for CASE statement. + * @param conditions Collection of conditions which correspond to WHEN clauses. + * @param conditionalBodies Collection of bodies that have a corresponding condition, + * in WHEN branches. + * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch. + */ +case class CaseStatement( + conditions: Seq[SingleStatement], + conditionalBodies: Seq[CompoundBody], + elseBody: Option[CompoundBody]) extends CompoundPlanStatement { + assert(conditions.length == conditionalBodies.length) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala index 83e50aa33c70d..043214711ccf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala @@ -21,10 +21,10 @@ import org.apache.spark.SparkException import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, UnaryExpression, Unevaluable} import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.util.GeneratedColumn +import org.apache.spark.sql.catalyst.util.{GeneratedColumn, IdentityColumn} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.validateDefaultValueExpr import org.apache.spark.sql.catalyst.util.ResolveDefaultColumnsUtils.{CURRENT_DEFAULT_COLUMN_METADATA_KEY, EXISTS_DEFAULT_COLUMN_METADATA_KEY} -import org.apache.spark.sql.connector.catalog.{Column => V2Column, ColumnDefaultValue} +import org.apache.spark.sql.connector.catalog.{Column => V2Column, ColumnDefaultValue, IdentityColumnSpec} import org.apache.spark.sql.connector.expressions.LiteralValue import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.connector.ColumnImpl @@ -41,7 +41,11 @@ case class ColumnDefinition( comment: Option[String] = None, defaultValue: Option[DefaultValueExpression] = None, generationExpression: Option[String] = None, + identityColumnSpec: Option[IdentityColumnSpec] = None, metadata: Metadata = Metadata.empty) extends Expression with Unevaluable { + assert( + generationExpression.isEmpty || identityColumnSpec.isEmpty, + "A ColumnDefinition cannot contain both a generation expression and an identity column spec.") override def children: Seq[Expression] = defaultValue.toSeq @@ -58,6 +62,7 @@ case class ColumnDefinition( comment.orNull, defaultValue.map(_.toV2(statement, name)).orNull, generationExpression.orNull, + identityColumnSpec.orNull, if (metadata == Metadata.empty) null else metadata.json) } @@ -75,8 +80,19 @@ case class ColumnDefinition( generationExpression.foreach { generationExpr => 
metadataBuilder.putString(GeneratedColumn.GENERATION_EXPRESSION_METADATA_KEY, generationExpr) } + encodeIdentityColumnSpec(metadataBuilder) StructField(name, dataType, nullable, metadataBuilder.build()) } + + private def encodeIdentityColumnSpec(metadataBuilder: MetadataBuilder): Unit = { + identityColumnSpec.foreach { spec: IdentityColumnSpec => + metadataBuilder.putLong(IdentityColumn.IDENTITY_INFO_START, spec.getStart) + metadataBuilder.putLong(IdentityColumn.IDENTITY_INFO_STEP, spec.getStep) + metadataBuilder.putBoolean( + IdentityColumn.IDENTITY_INFO_ALLOW_EXPLICIT_INSERT, + spec.isAllowExplicitInsert) + } + } } object ColumnDefinition { @@ -87,6 +103,9 @@ object ColumnDefinition { metadataBuilder.remove(CURRENT_DEFAULT_COLUMN_METADATA_KEY) metadataBuilder.remove(EXISTS_DEFAULT_COLUMN_METADATA_KEY) metadataBuilder.remove(GeneratedColumn.GENERATION_EXPRESSION_METADATA_KEY) + metadataBuilder.remove(IdentityColumn.IDENTITY_INFO_START) + metadataBuilder.remove(IdentityColumn.IDENTITY_INFO_STEP) + metadataBuilder.remove(IdentityColumn.IDENTITY_INFO_ALLOW_EXPLICIT_INSERT) val hasDefaultValue = col.getCurrentDefaultValue().isDefined && col.getExistenceDefaultValue().isDefined @@ -97,6 +116,15 @@ object ColumnDefinition { None } val generationExpr = GeneratedColumn.getGenerationExpression(col) + val identityColumnSpec = if (col.metadata.contains(IdentityColumn.IDENTITY_INFO_START)) { + Some(new IdentityColumnSpec( + col.metadata.getLong(IdentityColumn.IDENTITY_INFO_START), + col.metadata.getLong(IdentityColumn.IDENTITY_INFO_STEP), + col.metadata.getBoolean(IdentityColumn.IDENTITY_INFO_ALLOW_EXPLICIT_INSERT) + )) + } else { + None + } ColumnDefinition( col.name, col.dataType, @@ -104,6 +132,7 @@ object ColumnDefinition { col.getComment(), defaultValue, generationExpr, + identityColumnSpec, metadataBuilder.build() ) } @@ -124,18 +153,8 @@ object ColumnDefinition { s"Command $cmd should not have column default value expression.") } cmd.columns.foreach { col => - if (col.defaultValue.isDefined && col.generationExpression.isDefined) { - throw new AnalysisException( - errorClass = "GENERATED_COLUMN_WITH_DEFAULT_VALUE", - messageParameters = Map( - "colName" -> col.name, - "defaultValue" -> col.defaultValue.get.originalSQL, - "genExpr" -> col.generationExpression.get - ) - ) - } - col.defaultValue.foreach { default => + checkDefaultColumnConflicts(col) validateDefaultValueExpr(default, statement, col.name, col.dataType) } } @@ -143,6 +162,29 @@ object ColumnDefinition { case _ => } } + + private def checkDefaultColumnConflicts(col: ColumnDefinition): Unit = { + if (col.generationExpression.isDefined) { + throw new AnalysisException( + errorClass = "GENERATED_COLUMN_WITH_DEFAULT_VALUE", + messageParameters = Map( + "colName" -> col.name, + "defaultValue" -> col.defaultValue.get.originalSQL, + "genExpr" -> col.generationExpression.get + ) + ) + } + if (col.identityColumnSpec.isDefined) { + throw new AnalysisException( + errorClass = "IDENTITY_COLUMN_WITH_DEFAULT_VALUE", + messageParameters = Map( + "colName" -> col.name, + "defaultValue" -> col.defaultValue.get.originalSQL, + "identityColumnSpec" -> col.identityColumnSpec.get.toString + ) + ) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala new file mode 100644 index 0000000000000..dc8dbf701f6a9 --- /dev/null +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +/** + * A logical plan node that requires execution during analysis. + */ +trait ExecutableDuringAnalysis extends LogicalPlan { + /** + * Returns the logical plan node that should be used for EXPLAIN. + */ + def stageForExplain(): LogicalPlan +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala index 4701f4ea1e172..75b2fcd3a5f34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.{Expression, NamedArgumentExpression} +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns +import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.util.ArrayImplicits._ @@ -122,12 +124,32 @@ object NamedParametersSupport { functionSignature: FunctionSignature, args: Seq[Expression], functionName: String): Seq[Expression] = { - val parameters: Seq[InputParameter] = functionSignature.parameters + defaultRearrange(functionName, functionSignature.parameters, args) + } + + final def defaultRearrange(procedure: BoundProcedure, args: Seq[Expression]): Seq[Expression] = { + defaultRearrange( + procedure.name, + procedure.parameters.map(toInputParameter).toSeq, + args) + } + + private def toInputParameter(param: ProcedureParameter): InputParameter = { + val defaultValue = Option(param.defaultValueExpression).map { expr => + ResolveDefaultColumns.analyze(param.name, param.dataType, expr, "CALL") + } + InputParameter(param.name, defaultValue) + } + + private def defaultRearrange( + routineName: String, + parameters: Seq[InputParameter], + args: Seq[Expression]): Seq[Expression] = { if (parameters.dropWhile(_.default.isEmpty).exists(_.default.isEmpty)) { - throw QueryCompilationErrors.unexpectedRequiredParameter(functionName, parameters) + throw QueryCompilationErrors.unexpectedRequiredParameter(routineName, parameters) } - val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, functionName) + val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, routineName) val namedParameters: Seq[InputParameter] = parameters.drop(positionalArgs.size) // The following loop 
checks for the following: @@ -140,12 +162,12 @@ object NamedParametersSupport { namedArgs.foreach { namedArg => val parameterName = namedArg.key if (!parameterNamesSet.contains(parameterName)) { - throw QueryCompilationErrors.unrecognizedParameterName(functionName, namedArg.key, + throw QueryCompilationErrors.unrecognizedParameterName(routineName, namedArg.key, parameterNamesSet.toSeq) } if (positionalParametersSet.contains(parameterName)) { throw QueryCompilationErrors.positionalAndNamedArgumentDoubleReference( - functionName, namedArg.key) + routineName, namedArg.key) } } @@ -154,7 +176,7 @@ object NamedParametersSupport { val validParameterSizes = Array.range(parameters.count(_.default.isEmpty), parameters.size + 1).toImmutableArraySeq throw QueryCompilationErrors.wrongNumArgsError( - functionName, validParameterSizes, args.length) + routineName, validParameterSizes, args.length) } // This constructs a map from argument name to value for argument rearrangement. @@ -168,7 +190,7 @@ object NamedParametersSupport { namedArgMap.getOrElse( param.name, if (param.default.isEmpty) { - throw QueryCompilationErrors.requiredParameterNotFound(functionName, param.name, index) + throw QueryCompilationErrors.requiredParameterNotFound(routineName, param.name, index) } else { param.default.get } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MultiResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MultiResult.scala new file mode 100644 index 0000000000000..f249e5c87eba2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MultiResult.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class MultiResult(children: Seq[LogicalPlan]) extends LogicalPlan { + + override def output: Seq[Attribute] = children.lastOption.map(_.output).getOrElse(Nil) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): MultiResult = { + copy(children = newChildren) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index e784e6695dbd0..90af6333b2e0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.{AliasIdentifier, SQLConfHelper} +import org.apache.spark.sql.catalyst.{AliasIdentifier, InternalRow, SQLConfHelper} import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, MultiInstanceRelation, Resolver, TypeCoercion, TypeCoercionBase, UnresolvedUnaryNode} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN @@ -992,12 +992,18 @@ object Range { castAndEval[Int](expression, IntegerType, paramIndex, paramName) } +// scalastyle:off line.size.limit @ExpressionDescription( usage = """ - _FUNC_(start: long, end: long, step: long, numSlices: integer) - _FUNC_(start: long, end: long, step: long) - _FUNC_(start: long, end: long) - _FUNC_(end: long)""", + _FUNC_(start[, end[, step[, numSlices]]]) / _FUNC_(end) - Returns a table of values within a specified range. + """, + arguments = """ + Arguments: + * start - An optional BIGINT literal defaulted to 0, marking the first value generated. + * end - A BIGINT literal marking endpoint (exclusive) of the number generation. + * step - An optional BIGINT literal defaulted to 1, specifying the increment used when generating values. + * numParts - An optional INTEGER literal specifying how the production of rows is spread across partitions. + """, examples = """ Examples: > SELECT * FROM _FUNC_(1); @@ -1023,6 +1029,7 @@ object Range { """, since = "2.0.0", group = "table_funcs") +// scalastyle:on line.size.limit case class Range( start: Long, end: Long, @@ -1440,7 +1447,7 @@ case class Offset(offsetExpr: Expression, child: LogicalPlan) extends OrderPrese } /** - * A constructor for creating a pivot, which will later be converted to a [[Project]] + * A logical plan node for creating a pivot, which will later be converted to a [[Project]] * or an [[Aggregate]] during the query analysis. * * @param groupByExprsOpt A sequence of group by expressions. This field should be None if coming @@ -1474,9 +1481,27 @@ case class Pivot( override protected def withNewChildInternal(newChild: LogicalPlan): Pivot = copy(child = newChild) } +/** + * A logical plan node for transpose, which will later be converted to a [[LocalRelation]] + * at ReplaceTranspose during the query optimization. + * + * The result of the transpose operation is held in the `data` field, and the corresponding + * schema is stored in the `output` field. The `Transpose` node does not depend on any child + * logical plans after the data has been collected and transposed. 
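A hedged sketch of how the Transpose node is exercised end to end. The DataFrame-side entry point is not part of this hunk; the `transpose()` call below assumes such an API exists on Dataset and that the first column supplies the new header values, with ReplaceTranspose later collapsing the resolved node into a LocalRelation during optimization.

import org.apache.spark.sql.SparkSession

object TransposeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._
    val df = Seq(("x", 1, 2), ("y", 3, 4)).toDF("name", "a", "b")
    // Assumption: rows become columns keyed by the `name` values, yielding columns x and y.
    df.transpose().show()
    spark.stop()
  }
}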
+ * + * @param output A sequence of output attributes representing the schema of the transposed data. + * @param data A sequence of [[InternalRow]] containing the transposed data. + */ +case class Transpose( + output: Seq[Attribute], + data: Seq[InternalRow] = Nil +) extends LeafNode { + final override val nodePatterns: Seq[TreePattern] = Seq(TRANSPOSE) +} + /** - * A constructor for creating an Unpivot, which will later be converted to an [[Expand]] + * A logical plan node for creating an Unpivot, which will later be converted to an [[Expand]] * during the query analysis. * * Either ids or values array must be set. The ids array can be empty, @@ -1582,7 +1607,7 @@ case class Unpivot( } /** - * A constructor for creating a logical limit, which is split into two separate logical nodes: + * A logical plan node for creating a logical limit, which is split into two separate logical nodes: * a [[LocalLimit]], which is a partition local limit, followed by a [[GlobalLimit]]. * * This muds the water for clean logical/physical separation, and is done for better limit pushdown. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 05628d7b1c98e..b465e0e11612f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -19,17 +19,22 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, UnresolvedException, ViewSchemaMode} +import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedException, UnresolvedProcedure, ViewSchemaMode} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.FunctionResource import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, NamedExpression, UnaryExpression, Unevaluable, V2ExpressionUtils} import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema import org.apache.spark.sql.catalyst.trees.BinaryLike -import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, RowDeltaUtils, WriteDeltaProjections} +import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils, RowDeltaUtils, WriteDeltaProjections} +import org.apache.spark.sql.catalyst.util.TypeUtils.{ordinalNumber, toSQLExpr} import org.apache.spark.sql.connector.catalog._ +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper} +import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.write.{DeltaWrite, RowLevelOperation, RowLevelOperationTable, SupportsDelta, Write} +import 
org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType} import org.apache.spark.util.ArrayImplicits._ @@ -796,9 +801,9 @@ case class MergeIntoTable( object MergeIntoTable { def getWritePrivileges( - matchedActions: Seq[MergeAction], - notMatchedActions: Seq[MergeAction], - notMatchedBySourceActions: Seq[MergeAction]): Seq[TableWritePrivilege] = { + matchedActions: Iterable[MergeAction], + notMatchedActions: Iterable[MergeAction], + notMatchedBySourceActions: Iterable[MergeAction]): Seq[TableWritePrivilege] = { val privileges = scala.collection.mutable.HashSet.empty[TableWritePrivilege] (matchedActions.iterator ++ notMatchedActions ++ notMatchedBySourceActions).foreach { case _: DeleteAction => privileges.add(TableWritePrivilege.DELETE) @@ -1571,3 +1576,61 @@ case class SetVariable( override protected def withNewChildInternal(newChild: LogicalPlan): SetVariable = copy(sourceQuery = newChild) } + +/** + * The logical plan of the CALL statement. + */ +case class Call( + procedure: LogicalPlan, + args: Seq[Expression], + execute: Boolean = true) + extends UnaryNode with ExecutableDuringAnalysis { + + override def output: Seq[Attribute] = Nil + + override def child: LogicalPlan = procedure + + def bound: Boolean = procedure match { + case ResolvedProcedure(_, _, _: BoundProcedure) => true + case _ => false + } + + def checkArgTypes(): TypeCheckResult = { + require(resolved && bound, "can check arg types only after resolution and binding") + + val params = procedure match { + case ResolvedProcedure(_, _, bound: BoundProcedure) => bound.parameters + } + require(args.length == params.length, "number of args and params must match after binding") + + args.zip(params).zipWithIndex.collectFirst { + case ((arg, param), idx) + if !DataType.equalsIgnoreCompatibleNullability(arg.dataType, param.dataType) => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(idx), + "requiredType" -> toSQLType(param.dataType), + "inputSql" -> toSQLExpr(arg), + "inputType" -> toSQLType(arg.dataType))) + }.getOrElse(TypeCheckSuccess) + } + + override def simpleString(maxFields: Int): String = { + val name = procedure match { + case ResolvedProcedure(catalog, ident, _) => + s"${quoteIfNeeded(catalog.name)}.${ident.quoted}" + case UnresolvedProcedure(nameParts) => + nameParts.quoted + } + val argsString = truncatedString(args, ", ", maxFields) + s"Call $name($argsString)" + } + + override def stageForExplain(): Call = { + copy(execute = false) + } + + override protected def withNewChildInternal(newChild: LogicalPlan): Call = + copy(procedure = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 8be7aac7bebf5..b5556cbae7cd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -54,6 +54,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveDeserializer" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions" :: + 
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveProcedures" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGenerate" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics" :: "org.apache.spark.sql.catalyst.analysis.ResolveHigherOrderFunctions" :: @@ -70,6 +71,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubqueryColumnAliases" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveTables" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveTempViews" :: + "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveTranspose" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUnpivot" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUserSpecifiedColumns" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 6258bd615b440..0f1c98b53e0b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -72,6 +72,7 @@ object TreePattern extends Enumeration { val NOT: Value = Value val NULL_CHECK: Value = Value val NULL_LITERAL: Value = Value + val PIPE_OPERATOR_SELECT: Value = Value val SERIALIZE_FROM_OBJECT: Value = Value val OR: Value = Value val OUTER_REFERENCE: Value = Value @@ -92,6 +93,7 @@ object TreePattern extends Enumeration { val SUM: Value = Value val TIME_WINDOW: Value = Value val TIME_ZONE_AWARE_EXPRESSION: Value = Value + val TRANSPOSE: Value = Value val TRUE_OR_FALSE_LITERAL: Value = Value val VARIANT_GET: Value = Value val WINDOW_EXPRESSION: Value = Value @@ -155,8 +157,10 @@ object TreePattern extends Enumeration { // Unresolved Plan patterns (Alphabetically ordered) val UNRESOLVED_FUNC: Value = Value + val UNRESOLVED_PROCEDURE: Value = Value val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value + val UNRESOLVED_TRANSPOSE: Value = Value val UNRESOLVED_TVF_ALIASES: Value = Value // Execution expression patterns (alphabetically ordered) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f1c36f2f5c28f..e27ce29fc2318 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -389,7 +389,7 @@ object DateTimeUtils extends SparkDateTimeUtils { case "SA" | "SAT" | "SATURDAY" => SATURDAY case _ => throw new SparkIllegalArgumentException( - errorClass = "_LEGACY_ERROR_TEMP_3209", + errorClass = "ILLEGAL_DAY_OF_WEEK", messageParameters = Map("string" -> string.toString)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IdentityColumn.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IdentityColumn.scala new file mode 100644 index 0000000000000..26a3cb026d317 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IdentityColumn.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.connector.catalog.{Identifier, IdentityColumnSpec, TableCatalog, TableCatalogCapability} +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * This object contains utility methods and values for Identity Columns + */ +object IdentityColumn { + val IDENTITY_INFO_START = "identity.start" + val IDENTITY_INFO_STEP = "identity.step" + val IDENTITY_INFO_ALLOW_EXPLICIT_INSERT = "identity.allowExplicitInsert" + + /** + * If `schema` contains any generated columns, check whether the table catalog supports identity + * columns. Otherwise throw an error. + */ + def validateIdentityColumn( + schema: StructType, + catalog: TableCatalog, + ident: Identifier): Unit = { + if (hasIdentityColumns(schema)) { + if (!catalog + .capabilities() + .contains(TableCatalogCapability.SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS)) { + throw QueryCompilationErrors.unsupportedTableOperationError( + catalog, ident, operation = "identity column" + ) + } + } + } + + /** + * Whether the given `field` is an identity column + */ + def isIdentityColumn(field: StructField): Boolean = { + field.metadata.contains(IDENTITY_INFO_START) + } + + /** + * Returns the identity information stored in the column metadata if it exists + */ + def getIdentityInfo(field: StructField): Option[IdentityColumnSpec] = { + if (isIdentityColumn(field)) { + Some(new IdentityColumnSpec( + field.metadata.getString(IDENTITY_INFO_START).toLong, + field.metadata.getString(IDENTITY_INFO_STEP).toLong, + field.metadata.getString(IDENTITY_INFO_ALLOW_EXPLICIT_INSERT).toBoolean)) + } else { + None + } + } + + /** + * Whether the `schema` has one or more identity columns + */ + def hasIdentityColumns(schema: StructType): Boolean = { + schema.exists(isIdentityColumn) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index 0067114e36fdd..90c802b7e28df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -105,16 +105,18 @@ object IntervalUtils extends SparkIntervalUtils { endField: Byte, intervalStr: String, typeName: String, - fallBackNotice: Option[String] = None) = { + fallBackNotice: Boolean = false) = { throw new SparkIllegalArgumentException( - errorClass = "_LEGACY_ERROR_TEMP_3214", + errorClass = { + if (fallBackNotice) "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE" + else "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING" + }, messageParameters = Map( "intervalStr" -> intervalStr, "supportedFormat" -> supportedFormat((intervalStr, startFiled, endField)) .map(format => s"`$format`").mkString(", "), "typeName" -> typeName, - "input" -> 
input.toString, - "fallBackNotice" -> fallBackNotice.map(s => s", $s").getOrElse(""))) + "input" -> input.toString)) } val supportedFormat = Map( @@ -145,14 +147,15 @@ object IntervalUtils extends SparkIntervalUtils { def checkTargetType(targetStartField: Byte, targetEndField: Byte): Boolean = startField == targetStartField && endField == targetEndField - input.trimAll().toString match { + val trimmedInput = input.trimAll().toString + trimmedInput match { case yearMonthRegex(sign, year, month) if checkTargetType(YM.YEAR, YM.MONTH) => - toYMInterval(year, month, finalSign(sign)) + toYMInterval(year, month, trimmedInput, finalSign(sign)) case yearMonthLiteralRegex(firstSign, secondSign, year, month) if checkTargetType(YM.YEAR, YM.MONTH) => - toYMInterval(year, month, finalSign(firstSign, secondSign)) + toYMInterval(year, month, trimmedInput, finalSign(firstSign, secondSign)) case yearMonthIndividualRegex(firstSign, value) => - safeToInterval("year-month") { + safeToInterval("year-month", trimmedInput) { val sign = finalSign(firstSign) if (endField == YM.YEAR) { sign * Math.toIntExact(value.toLong * MONTHS_PER_YEAR) @@ -164,7 +167,7 @@ object IntervalUtils extends SparkIntervalUtils { } } case yearMonthIndividualLiteralRegex(firstSign, secondSign, value, unit) => - safeToInterval("year-month") { + safeToInterval("year-month", trimmedInput) { val sign = finalSign(firstSign, secondSign) unit.toUpperCase(Locale.ROOT) match { case "YEAR" if checkTargetType(YM.YEAR, YM.YEAR) => @@ -202,21 +205,21 @@ object IntervalUtils extends SparkIntervalUtils { new CalendarInterval(months, 0, 0) } - private def safeToInterval[T](interval: String)(f: => T): T = { + private def safeToInterval[T](interval: String, input: String)(f: => T): T = { try { f } catch { case e: SparkThrowable => throw e case NonFatal(e) => throw new SparkIllegalArgumentException( - errorClass = "_LEGACY_ERROR_TEMP_3213", - messageParameters = Map("interval" -> interval, "msg" -> e.getMessage), + errorClass = "INVALID_INTERVAL_FORMAT.INTERVAL_PARSING", + messageParameters = Map("input" -> input, "interval" -> interval), cause = e) } } - private def toYMInterval(year: String, month: String, sign: Int): Int = { - safeToInterval("year-month") { + private def toYMInterval(year: String, month: String, input: String, sign: Int): Int = { + safeToInterval("year-month", input) { val years = toLongWithRange(yearStr, year, 0, Integer.MAX_VALUE / MONTHS_PER_YEAR) val totalMonths = sign * (years * MONTHS_PER_YEAR + toLongWithRange(monthStr, month, 0, 11)) @@ -285,7 +288,8 @@ object IntervalUtils extends SparkIntervalUtils { def checkTargetType(targetStartField: Byte, targetEndField: Byte): Boolean = startField == targetStartField && endField == targetEndField - input.trimAll().toString match { + val trimmedInput = input.trimAll().toString + trimmedInput match { case dayHourRegex(sign, day, hour) if checkTargetType(DT.DAY, DT.HOUR) => toDTInterval(day, hour, "0", "0", finalSign(sign)) case dayHourLiteralRegex(firstSign, secondSign, day, hour) @@ -324,7 +328,7 @@ object IntervalUtils extends SparkIntervalUtils { toDTInterval(minute, secondAndMicro(second, micro), finalSign(firstSign, secondSign)) case dayTimeIndividualRegex(firstSign, value, suffix) => - safeToInterval("day-time") { + safeToInterval("day-time", trimmedInput) { val sign = finalSign(firstSign) (startField, endField) match { case (DT.DAY, DT.DAY) if suffix == null && value.length <= 9 => @@ -339,11 +343,11 @@ object IntervalUtils extends SparkIntervalUtils { case -1 => 
parseSecondNano(s"-${secondAndMicro(value, suffix)}") } case (_, _) => throwIllegalIntervalFormatException(input, startField, endField, - "day-time", DT(startField, endField).typeName, Some(fallbackNotice)) + "day-time", DT(startField, endField).typeName, true) } } case dayTimeIndividualLiteralRegex(firstSign, secondSign, value, suffix, unit) => - safeToInterval("day-time") { + safeToInterval("day-time", trimmedInput) { val sign = finalSign(firstSign, secondSign) unit.toUpperCase(Locale.ROOT) match { case "DAY" if suffix == null && value.length <= 9 && checkTargetType(DT.DAY, DT.DAY) => @@ -360,11 +364,11 @@ object IntervalUtils extends SparkIntervalUtils { case -1 => parseSecondNano(s"-${secondAndMicro(value, suffix)}") } case _ => throwIllegalIntervalFormatException(input, startField, endField, - "day-time", DT(startField, endField).typeName, Some(fallbackNotice)) + "day-time", DT(startField, endField).typeName, true) } } case _ => throwIllegalIntervalFormatException(input, startField, endField, - "day-time", DT(startField, endField).typeName, Some(fallbackNotice)) + "day-time", DT(startField, endField).typeName, true) } } @@ -512,7 +516,7 @@ object IntervalUtils extends SparkIntervalUtils { case DT.SECOND => // No-op case _ => throw new SparkIllegalArgumentException( - errorClass = "_LEGACY_ERROR_TEMP_3212", + errorClass = "INVALID_INTERVAL_FORMAT.UNSUPPORTED_FROM_TO_EXPRESSION", messageParameters = Map( "input" -> input, "from" -> DT.fieldToString(from), @@ -524,10 +528,14 @@ object IntervalUtils extends SparkIntervalUtils { micros = Math.addExact(micros, Math.multiplyExact(seconds, MICROS_PER_SECOND)) new CalendarInterval(0, sign * days, sign * micros) } catch { + // Bypass SparkIllegalArgumentExceptions + case se: SparkIllegalArgumentException => throw se case e: Exception => throw new SparkIllegalArgumentException( - errorClass = "_LEGACY_ERROR_TEMP_3211", - messageParameters = Map("msg" -> e.getMessage), + errorClass = "INVALID_INTERVAL_FORMAT.DAY_TIME_PARSING", + messageParameters = Map( + "msg" -> e.getMessage, + "input" -> input), cause = e) } } @@ -564,7 +572,9 @@ object IntervalUtils extends SparkIntervalUtils { case Array(secondsStr, nanosStr) => val seconds = parseSeconds(secondsStr) Math.addExact(seconds, parseNanos(nanosStr, seconds < 0)) - case _ => throw new SparkIllegalArgumentException("_LEGACY_ERROR_TEMP_3210") + case _ => throw new SparkIllegalArgumentException( + errorClass = "INVALID_INTERVAL_FORMAT.SECOND_NANO_FORMAT", + messageParameters = Map("input" -> secondNano)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala index 8b7392e71249e..693ac8d94dbcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkThrowable, SparkUnsupportedOperationException} +import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.AnalysisException @@ -412,8 +412,11 @@ object ResolveDefaultColumns extends QueryErrorsBase case _: ExprLiteral | _: Cast => expr } } catch { - case _: 
AnalysisException | _: MatchError => - throw QueryCompilationErrors.failedToParseExistenceDefaultAsLiteral(field.name, text) + // AnalysisException thrown from analyze is already formatted, throw it directly. + case ae: AnalysisException => throw ae + case _: MatchError => + throw SparkException.internalError(s"parse existence default as literal err," + + s" field name: ${field.name}, value: $text") } // The expression should be a literal value by this point, possibly wrapped in a cast // function. This is enforced by the execution of commands that assign default values. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 65bdae85be12a..282350dda67d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -126,6 +126,13 @@ private[sql] object CatalogV2Implicits { case _ => throw QueryCompilationErrors.missingCatalogAbilityError(plugin, "functions") } + + def asProcedureCatalog: ProcedureCatalog = plugin match { + case procedureCatalog: ProcedureCatalog => + procedureCatalog + case _ => + throw QueryCompilationErrors.missingCatalogAbilityError(plugin, "procedures") + } } implicit class NamespaceHelper(namespace: Array[String]) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 6698f0a021400..9b7f68070a1a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.{AsOfTimestamp, AsOfVersion, Named import org.apache.spark.sql.catalyst.catalog.ClusterBySpec import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical.{SerdeInfo, TableSpec} -import org.apache.spark.sql.catalyst.util.GeneratedColumn +import org.apache.spark.sql.catalyst.util.{GeneratedColumn, IdentityColumn} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction @@ -579,18 +579,10 @@ private[sql] object CatalogV2Util { val isDefaultColumn = f.getCurrentDefaultValue().isDefined && f.getExistenceDefaultValue().isDefined val isGeneratedColumn = GeneratedColumn.isGeneratedColumn(f) - if (isDefaultColumn && isGeneratedColumn) { - throw new AnalysisException( - errorClass = "GENERATED_COLUMN_WITH_DEFAULT_VALUE", - messageParameters = Map( - "colName" -> f.name, - "defaultValue" -> f.getCurrentDefaultValue().get, - "genExpr" -> GeneratedColumn.getGenerationExpression(f).get - ) - ) - } - + val isIdentityColumn = IdentityColumn.isIdentityColumn(f) if (isDefaultColumn) { + checkDefaultColumnConflicts(f) + val e = analyze( f, statementType = "Column analysis", @@ -611,10 +603,41 @@ private[sql] object CatalogV2Util { Seq("comment", GeneratedColumn.GENERATION_EXPRESSION_METADATA_KEY)) Column.create(f.name, f.dataType, f.nullable, f.getComment().orNull, GeneratedColumn.getGenerationExpression(f).get, metadataAsJson(cleanedMetadata)) + } else if (isIdentityColumn) { + val cleanedMetadata = metadataWithKeysRemoved( + 
Seq("comment", + IdentityColumn.IDENTITY_INFO_START, + IdentityColumn.IDENTITY_INFO_STEP, + IdentityColumn.IDENTITY_INFO_ALLOW_EXPLICIT_INSERT)) + Column.create(f.name, f.dataType, f.nullable, f.getComment().orNull, + IdentityColumn.getIdentityInfo(f).get, metadataAsJson(cleanedMetadata)) } else { val cleanedMetadata = metadataWithKeysRemoved(Seq("comment")) Column.create(f.name, f.dataType, f.nullable, f.getComment().orNull, metadataAsJson(cleanedMetadata)) } } + + private def checkDefaultColumnConflicts(f: StructField): Unit = { + if (GeneratedColumn.isGeneratedColumn(f)) { + throw new AnalysisException( + errorClass = "GENERATED_COLUMN_WITH_DEFAULT_VALUE", + messageParameters = Map( + "colName" -> f.name, + "defaultValue" -> f.getCurrentDefaultValue().get, + "genExpr" -> GeneratedColumn.getGenerationExpression(f).get + ) + ) + } + if (IdentityColumn.isIdentityColumn(f)) { + throw new AnalysisException( + errorClass = "IDENTITY_COLUMN_WITH_DEFAULT_VALUE", + messageParameters = Map( + "colName" -> f.name, + "defaultValue" -> f.getCurrentDefaultValue().get, + "identityColumnSpec" -> IdentityColumn.getIdentityInfo(f).get.toString + ) + ) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index fa8ea2f5289fa..0b5255e95f073 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -853,6 +853,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat origin = origin) } + def failedToLoadRoutineError(nameParts: Seq[String], e: Exception): Throwable = { + new AnalysisException( + errorClass = "FAILED_TO_LOAD_ROUTINE", + messageParameters = Map("routineName" -> toSQLId(nameParts)), + cause = Some(e)) + } + def unresolvedRoutineError(name: FunctionIdentifier, searchPath: Seq[String]): Throwable = { new AnalysisException( errorClass = "UNRESOLVED_ROUTINE", @@ -3516,29 +3523,21 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "cond" -> toSQLExpr(cond))) } - def failedToParseExistenceDefaultAsLiteral(fieldName: String, defaultValue: String): Throwable = { - new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1344", - messageParameters = Map( - "fieldName" -> fieldName, - "defaultValue" -> defaultValue)) - } - def defaultReferencesNotAllowedInDataSource( statementType: String, dataSource: String): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1345", + errorClass = "DEFAULT_UNSUPPORTED", messageParameters = Map( - "statementType" -> statementType, + "statementType" -> toSQLStmt(statementType), "dataSource" -> dataSource)) } def addNewDefaultColumnToExistingTableNotAllowed( statementType: String, dataSource: String): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1346", + errorClass = "ADD_DEFAULT_UNSUPPORTED", messageParameters = Map( - "statementType" -> statementType, + "statementType" -> toSQLStmt(statementType), "dataSource" -> dataSource)) } @@ -3959,6 +3958,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map("provider" -> name)) } + def externalDataSourceException(cause: Throwable): Throwable = { + new AnalysisException( + errorClass = "DATA_SOURCE_EXTERNAL_ERROR", + messageParameters = Map(), + cause = Some(cause) + ) + } + def 
foundMultipleDataSources(provider: String): Throwable = { new AnalysisException( errorClass = "FOUND_MULTIPLE_DATA_SOURCES", @@ -4090,10 +4097,33 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } + def avroOptionsException(optionName: String, message: String): Throwable = { + new AnalysisException( + errorClass = "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", + messageParameters = Map("optionName" -> optionName, "message" -> message) + ) + } + def protobufNotLoadedSqlFunctionsUnusable(functionName: String): Throwable = { new AnalysisException( errorClass = "PROTOBUF_NOT_LOADED_SQL_FUNCTIONS_UNUSABLE", messageParameters = Map("functionName" -> functionName) ) } + + def pipeOperatorSelectContainsAggregateFunction(expr: Expression): Throwable = { + new AnalysisException( + errorClass = "PIPE_OPERATOR_SELECT_CONTAINS_AGGREGATE_FUNCTION", + messageParameters = Map( + "expr" -> expr.toString), + origin = expr.origin) + } + + def inlineTableContainsScalarSubquery(inlineTable: LogicalPlan): Throwable = { + new AnalysisException( + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.SCALAR_SUBQUERY_IN_VALUES", + messageParameters = Map.empty, + origin = inlineTable.origin + ) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index d6e23fcc65cd4..4bc071155012b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -79,8 +79,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map( "value" -> toSQLValue(t, from), "sourceType" -> toSQLType(from), - "targetType" -> toSQLType(to), - "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + "targetType" -> toSQLType(to)), context = Array.empty, summary = "") } @@ -124,8 +123,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map( "expression" -> toSQLValue(s, StringType), "sourceType" -> toSQLType(StringType), - "targetType" -> toSQLType(BooleanType), - "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + "targetType" -> toSQLType(BooleanType)), context = getQueryContext(context), summary = getSummary(context)) } @@ -139,8 +137,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map( "expression" -> toSQLValue(s, StringType), "sourceType" -> toSQLType(StringType), - "targetType" -> toSQLType(to), - "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + "targetType" -> toSQLType(to)), context = getQueryContext(context), summary = getSummary(context)) } @@ -387,18 +384,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE s"The aggregate window function ${toSQLId(funcName)} does not support merging.") } - def dataTypeUnexpectedError(dataType: DataType): SparkUnsupportedOperationException = { - new SparkUnsupportedOperationException( - errorClass = "_LEGACY_ERROR_TEMP_2011", - messageParameters = Map("dataType" -> dataType.catalogString)) - } - - def typeUnsupportedError(dataType: DataType): SparkIllegalArgumentException = { - new SparkIllegalArgumentException( - errorClass = "_LEGACY_ERROR_TEMP_2011", - messageParameters = Map("dataType" -> dataType.toString())) - } - def negativeValueUnexpectedError( frequencyExpression : Expression): 
SparkIllegalArgumentException = { new SparkIllegalArgumentException( @@ -1310,10 +1295,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map("badRecord" -> badRecord)) } - def elementsOfTupleExceedLimitError(): SparkUnsupportedOperationException = { - new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_2150") - } - def expressionDecodingError(e: Exception, expressions: Seq[Expression]): SparkRuntimeException = { new SparkRuntimeException( errorClass = "EXPRESSION_DECODING_FAILED", @@ -1895,17 +1876,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = null) } - def notPublicClassError(name: String): SparkUnsupportedOperationException = { - new SparkUnsupportedOperationException( - errorClass = "_LEGACY_ERROR_TEMP_2229", - messageParameters = Map( - "name" -> name)) - } - - def primitiveTypesNotSupportedError(): SparkUnsupportedOperationException = { - new SparkUnsupportedOperationException(errorClass = "_LEGACY_ERROR_TEMP_2230") - } - def onlySupportDataSourcesProvidingFileFormatError(providingClass: String): Throwable = { new SparkException( errorClass = "_LEGACY_ERROR_TEMP_2233", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala index 61661b1d32f39..7f13dc334e06e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.errors import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLStmt import org.apache.spark.sql.exceptions.SqlScriptingException @@ -32,7 +33,7 @@ private[sql] object SqlScriptingErrors { origin = origin, errorClass = "LABELS_MISMATCH", cause = null, - messageParameters = Map("beginLabel" -> beginLabel, "endLabel" -> endLabel)) + messageParameters = Map("beginLabel" -> toSQLId(beginLabel), "endLabel" -> toSQLId(endLabel))) } def endLabelWithoutBeginLabel(origin: Origin, endLabel: String): Throwable = { @@ -40,29 +41,27 @@ private[sql] object SqlScriptingErrors { origin = origin, errorClass = "END_LABEL_WITHOUT_BEGIN_LABEL", cause = null, - messageParameters = Map("endLabel" -> endLabel)) + messageParameters = Map("endLabel" -> toSQLId(endLabel))) } def variableDeclarationNotAllowedInScope( origin: Origin, - varName: String, - lineNumber: String): Throwable = { + varName: Seq[String]): Throwable = { new SqlScriptingException( origin = origin, errorClass = "INVALID_VARIABLE_DECLARATION.NOT_ALLOWED_IN_SCOPE", cause = null, - messageParameters = Map("varName" -> varName, "lineNumber" -> lineNumber)) + messageParameters = Map("varName" -> toSQLId(varName))) } def variableDeclarationOnlyAtBeginning( origin: Origin, - varName: String, - lineNumber: String): Throwable = { + varName: Seq[String]): Throwable = { new SqlScriptingException( origin = origin, errorClass = "INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING", cause = null, - messageParameters = Map("varName" -> varName, "lineNumber" -> lineNumber)) + messageParameters = Map("varName" -> toSQLId(varName))) } def invalidBooleanStatement( @@ -84,4 +83,27 @@ private[sql] object SqlScriptingErrors { cause = null, messageParameters = Map("invalidStatement" -> toSQLStmt(stmt))) } + + def labelDoesNotExist( + origin: Origin, + 
labelName: String, + statementType: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "INVALID_LABEL_USAGE.DOES_NOT_EXIST", + cause = null, + messageParameters = Map( + "labelName" -> toSQLStmt(labelName), + "statementType" -> statementType)) + } + + def invalidIterateLabelUsageForCompound( + origin: Origin, + labelName: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND", + cause = null, + messageParameters = Map("labelName" -> toSQLStmt(labelName))) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala index 4354e7e3635e4..f0c28c95046eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.exceptions.SqlScriptingException.errorMessageWithLin class SqlScriptingException ( errorClass: String, cause: Throwable, - origin: Origin, + val origin: Origin, messageParameters: Map[String, String] = Map.empty) extends Exception( errorMessageWithLineNumber(Option(origin), errorClass, messageParameters), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 72915f0e5c256..6c3e9bac1cfe5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1843,6 +1843,13 @@ object SQLConf { .intConf .createWithDefault(10000) + val DATAFRAME_TRANSPOSE_MAX_VALUES = buildConf("spark.sql.transposeMaxValues") + .doc("When doing a transpose without specifying values for the index column this is" + + " the maximum number of values that will be transposed without error.") + .version("4.0.0") + .intConf + .createWithDefault(500) + val RUN_SQL_ON_FILES = buildConf("spark.sql.runSQLOnFiles") .internal() .doc("When true, we could use `datasource`.`path` as table in SQL query.") @@ -3162,6 +3169,29 @@ object SQLConf { .version("4.0.0") .fallbackConf(Python.PYTHON_WORKER_FAULTHANLDER_ENABLED) + val PYSPARK_PLOT_MAX_ROWS = + buildConf("spark.sql.pyspark.plotting.max_rows") + .doc( + "The visual limit on top-n-based plots. If set to 1000, the first 1000 data points " + + "will be used for plotting.") + .version("4.0.0") + .intConf + .createWithDefault(1000) + + val PYSPARK_PLOT_SAMPLE_RATIO = + buildConf("spark.sql.pyspark.plotting.sample_ratio") + .doc( + "The proportion of data that will be plotted for sample-based plots. It is determined " + + "based on spark.sql.pyspark.plotting.max_rows if not explicitly set." + ) + .version("4.0.0") + .doubleConf + .checkValue( + ratio => ratio >= 0.0 && ratio <= 1.0, + "The value should be between 0.0 and 1.0 inclusive." + ) + .createOptional + val ARROW_SPARKR_EXECUTION_ENABLED = buildConf("spark.sql.execution.arrow.sparkr.enabled") .doc("When true, make use of Apache Arrow for columnar data transfers in SparkR. 
" + @@ -3820,6 +3850,15 @@ object SQLConf { .intConf .createWithDefault(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) + val PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN = + buildConf("spark.databricks.sql.optimizer.pruneFiltersCanPruneStreamingSubplan") + .internal() + .doc("Allow PruneFilters to remove streaming subplans when we encounter a false filter. " + + "This flag is to restore prior buggy behavior for broken pipelines.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -4982,6 +5021,15 @@ object SQLConf { .stringConf .createWithDefault("versionAsOf") + val OPERATOR_PIPE_SYNTAX_ENABLED = + buildConf("spark.sql.operatorPipeSyntaxEnabled") + .doc("If true, enable operator pipe syntax for Apache Spark SQL. This uses the operator " + + "pipe marker |> to indicate separation between clauses of SQL in a manner that describes " + + "the sequence of steps that the query performs in a composable fashion.") + .version("4.0.0") + .booleanConf + .createWithDefault(Utils.isTesting) + val LEGACY_PERCENTILE_DISC_CALCULATION = buildConf("spark.sql.legacy.percentileDiscCalculation") .internal() .doc("If true, the old bogus percentile_disc calculation is used. The old calculation " + @@ -5712,6 +5760,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def dataFramePivotMaxValues: Int = getConf(DATAFRAME_PIVOT_MAX_VALUES) + def dataFrameTransposeMaxValues: Int = getConf(DATAFRAME_TRANSPOSE_MAX_VALUES) + def runSQLonFile: Boolean = getConf(RUN_SQL_ON_FILES) def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP) @@ -5846,6 +5896,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def pythonUDFWorkerFaulthandlerEnabled: Boolean = getConf(PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED) + def pysparkPlotMaxRows: Int = getConf(PYSPARK_PLOT_MAX_ROWS) + + def pysparkPlotSampleRatio: Option[Double] = getConf(PYSPARK_PLOT_SAMPLE_RATIO) + def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED) def arrowPySparkFallbackEnabled: Boolean = getConf(ARROW_PYSPARK_FALLBACK_ENABLED) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ColumnImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ColumnImpl.scala index 2a67ffc4bbef5..47889410561e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ColumnImpl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ColumnImpl.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.internal.connector -import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue} +import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, IdentityColumnSpec} import org.apache.spark.sql.types.DataType // The standard concrete implementation of data source V2 column. 
@@ -28,4 +28,5 @@ case class ColumnImpl( comment: String, defaultValue: ColumnDefaultValue, generationExpression: String, + identityColumnSpec: IdentityColumnSpec, metadataInJSON: String) extends Column diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala index 40e6182e587b3..50e933ba97ae6 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtilsSuite.scala @@ -349,7 +349,7 @@ class ExpressionImplUtilsSuite extends SparkFunSuite { exception = intercept[SparkRuntimeException] { f(t) }, - errorClass = t.expectedErrorClassOpt.get, + condition = t.expectedErrorClassOpt.get, parameters = t.errorParamsMap ) } @@ -361,7 +361,7 @@ class ExpressionImplUtilsSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { ExpressionImplUtils.validateUTF8String(str) }, - errorClass = "INVALID_UTF8_STRING", + condition = "INVALID_UTF8_STRING", parameters = Map( "str" -> str.getBytes.map(byte => f"\\x$byte%02X").mkString ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowJsonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowJsonSuite.scala index 3e72dc7da24b7..cf50063baa13c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowJsonSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowJsonSuite.scala @@ -136,7 +136,7 @@ class RowJsonSuite extends SparkFunSuite { new StructType().add("a", ObjectType(classOf[(Int, Int)]))) row.jsonValue }, - errorClass = "FAILED_ROW_TO_JSON", + condition = "FAILED_ROW_TO_JSON", parameters = Map( "value" -> toSQLValue("(1,2)"), "class" -> "class scala.Tuple2$mcII$sp", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index 562aac766fc33..7572843f44a19 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -108,7 +108,7 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { CatalystTypeConverters.createToCatalystConverter(structType)("test") }, - errorClass = "_LEGACY_ERROR_TEMP_3219", + condition = "_LEGACY_ERROR_TEMP_3219", parameters = Map( "other" -> "test", "otherClass" -> "java.lang.String", @@ -121,7 +121,7 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { CatalystTypeConverters.createToCatalystConverter(mapType)("test") }, - errorClass = "_LEGACY_ERROR_TEMP_3221", + condition = "_LEGACY_ERROR_TEMP_3221", parameters = Map( "other" -> "test", "otherClass" -> "java.lang.String", @@ -135,7 +135,7 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { CatalystTypeConverters.createToCatalystConverter(arrayType)("test") }, - errorClass = "_LEGACY_ERROR_TEMP_3220", + condition = "_LEGACY_ERROR_TEMP_3220", parameters = Map( "other" -> "test", "otherClass" -> "java.lang.String", @@ -148,7 +148,7 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper { 
exception = intercept[SparkIllegalArgumentException] { CatalystTypeConverters.createToCatalystConverter(decimalType)("test") }, - errorClass = "_LEGACY_ERROR_TEMP_3219", + condition = "_LEGACY_ERROR_TEMP_3219", parameters = Map( "other" -> "test", "otherClass" -> "java.lang.String", @@ -160,7 +160,7 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { CatalystTypeConverters.createToCatalystConverter(StringType)(0.1) }, - errorClass = "_LEGACY_ERROR_TEMP_3219", + condition = "_LEGACY_ERROR_TEMP_3219", parameters = Map( "other" -> "0.1", "otherClass" -> "java.lang.Double", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index daa8d12613f2c..a09dadbcd4816 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -494,7 +494,7 @@ class ScalaReflectionSuite extends SparkFunSuite { exception = intercept[SparkUnsupportedOperationException] { schemaFor[TraitProductWithoutCompanion] }, - errorClass = "_LEGACY_ERROR_TEMP_2144", + condition = "_LEGACY_ERROR_TEMP_2144", parameters = Map("tpe" -> "org.apache.spark.sql.catalyst.TraitProductWithoutCompanion")) } @@ -503,7 +503,7 @@ class ScalaReflectionSuite extends SparkFunSuite { exception = intercept[SparkUnsupportedOperationException] { schemaFor[TraitProductWithNoConstructorCompanion] }, - errorClass = "_LEGACY_ERROR_TEMP_2144", + condition = "_LEGACY_ERROR_TEMP_2144", parameters = Map("tpe" -> "org.apache.spark.sql.catalyst.TraitProductWithNoConstructorCompanion")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala index 8a71496607466..fc5d39fd9c2bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala @@ -474,7 +474,7 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkUnsupportedOperationException] { RangeShuffleSpec(10, distribution).createPartitioning(distribution.clustering) }, - errorClass = "UNSUPPORTED_CALL.WITHOUT_SUGGESTION", + condition = "UNSUPPORTED_CALL.WITHOUT_SUGGESTION", parameters = Map( "methodName" -> "createPartitioning$", "className" -> "org.apache.spark.sql.catalyst.plans.physical.ShuffleSpec")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index cfe08e1895363..70cc50a23a6a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -110,14 +110,15 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { } } - def errorClassTest( + def errorConditionTest( name: String, plan: LogicalPlan, - errorClass: String, + condition: String, messageParameters: Map[String, String], caseSensitive: Boolean = true): Unit = { test(name) { - assertAnalysisErrorClass(plan, errorClass, messageParameters, caseSensitive = caseSensitive) + assertAnalysisErrorCondition( + plan, condition, 
messageParameters, caseSensitive = caseSensitive) } } @@ -134,10 +135,10 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { testRelation.select(ScalarSubquery(LocalRelation()).as("a")), "Scalar subquery must return only one column, but got 0" :: Nil) - errorClassTest( + errorConditionTest( "single invalid type, single arg", testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as("a")), - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", messageParameters = Map( "sqlExpr" -> "\"testfunction(NULL)\"", "paramIndex" -> "first", @@ -145,11 +146,11 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { "inputType" -> "\"DATE\"", "requiredType" -> "\"INT\"")) - errorClassTest( + errorConditionTest( "single invalid type, second arg", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as("a")), - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", messageParameters = Map( "sqlExpr" -> "\"testfunction(NULL, NULL)\"", "paramIndex" -> "second", @@ -157,11 +158,11 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { "inputType" -> "\"DATE\"", "requiredType" -> "\"INT\"")) - errorClassTest( + errorConditionTest( "multiple invalid type", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as("a")), - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", messageParameters = Map( "sqlExpr" -> "\"testfunction(NULL, NULL)\"", "paramIndex" -> "first", @@ -169,17 +170,17 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { "inputType" -> "\"DATE\"", "requiredType" -> "\"INT\"")) - errorClassTest( + errorConditionTest( "SPARK-44477: type check failure", testRelation.select( TestFunctionWithTypeCheckFailure(dateLit :: Nil, BinaryType :: Nil).as("a")), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", messageParameters = Map( "sqlExpr" -> "\"testfunctionwithtypecheckfailure(NULL)\"", "msg" -> "Expression must be a binary", "hint" -> "")) - errorClassTest( + errorConditionTest( "invalid window function", testRelation2.select( WindowExpression( @@ -188,10 +189,10 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, UnspecifiedFrame)).as("window")), - errorClass = "UNSUPPORTED_EXPR_FOR_WINDOW", + condition = "UNSUPPORTED_EXPR_FOR_WINDOW", messageParameters = Map("sqlExpr" -> "\"0\"")) - errorClassTest( + errorConditionTest( "distinct aggregate function in window", testRelation2.select( WindowExpression( @@ -200,7 +201,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, UnspecifiedFrame)).as("window")), - errorClass = "DISTINCT_WINDOW_FUNCTION_UNSUPPORTED", + condition = "DISTINCT_WINDOW_FUNCTION_UNSUPPORTED", messageParameters = Map("windowExpr" -> s""" |"count(DISTINCT b) OVER (PARTITION BY a ORDER BY b ASC NULLS FIRST @@ -221,9 +222,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) test("distinct function") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CatalystSqlParser.parsePlan("SELECT hex(DISTINCT a) FROM 
TaBlE"), - expectedErrorClass = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", + expectedErrorCondition = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", expectedMessageParameters = Map( "prettyName" -> toSQLId("hex"), "syntax" -> toSQLStmt("DISTINCT")), @@ -231,9 +232,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { } test("non aggregate function with filter predicate") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CatalystSqlParser.parsePlan("SELECT hex(a) FILTER (WHERE c = 1) FROM TaBlE2"), - expectedErrorClass = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", + expectedErrorCondition = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", expectedMessageParameters = Map( "prettyName" -> toSQLId("hex"), "syntax" -> toSQLStmt("FILTER CLAUSE")), @@ -241,9 +242,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { } test("distinct window function") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CatalystSqlParser.parsePlan("SELECT percent_rank(DISTINCT a) OVER () FROM TaBlE"), - expectedErrorClass = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", + expectedErrorCondition = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", expectedMessageParameters = Map( "prettyName" -> toSQLId("percent_rank"), "syntax" -> toSQLStmt("DISTINCT")), @@ -251,10 +252,10 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { } test("window function with filter predicate") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CatalystSqlParser.parsePlan( "SELECT percent_rank(a) FILTER (WHERE c > 1) OVER () FROM TaBlE2"), - expectedErrorClass = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", + expectedErrorCondition = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", expectedMessageParameters = Map( "prettyName" -> toSQLId("percent_rank"), "syntax" -> toSQLStmt("FILTER CLAUSE")), @@ -262,7 +263,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { } test("window specification error") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = CatalystSqlParser.parsePlan( """ |WITH sample_data AS ( @@ -274,17 +275,17 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { |FROM sample_data |GROUP BY a, b; |""".stripMargin), - expectedErrorClass = "MISSING_WINDOW_SPECIFICATION", + expectedErrorCondition = "MISSING_WINDOW_SPECIFICATION", expectedMessageParameters = Map( "windowName" -> "b", "docroot" -> SPARK_DOC_ROOT)) } test("higher order function with filter predicate") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CatalystSqlParser.parsePlan("SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) " + "FILTER (WHERE c > 1)"), - expectedErrorClass = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", + expectedErrorCondition = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", expectedMessageParameters = Map( "prettyName" -> toSQLId("aggregate"), "syntax" -> toSQLStmt("FILTER CLAUSE")), @@ -293,9 +294,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { } test("function don't support ignore nulls") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CatalystSqlParser.parsePlan("SELECT hex(a) IGNORE NULLS FROM TaBlE2"), - expectedErrorClass = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", + expectedErrorCondition = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", expectedMessageParameters = Map( "prettyName" -> toSQLId("hex"), "syntax" 
-> toSQLStmt("IGNORE NULLS")), @@ -303,9 +304,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { } test("some window function don't support ignore nulls") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CatalystSqlParser.parsePlan("SELECT percent_rank(a) IGNORE NULLS FROM TaBlE2"), - expectedErrorClass = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", + expectedErrorCondition = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", expectedMessageParameters = Map( "prettyName" -> toSQLId("percent_rank"), "syntax" -> toSQLStmt("IGNORE NULLS")), @@ -313,9 +314,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { } test("aggregate function don't support ignore nulls") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CatalystSqlParser.parsePlan("SELECT count(a) IGNORE NULLS FROM TaBlE2"), - expectedErrorClass = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", + expectedErrorCondition = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", expectedMessageParameters = Map( "prettyName" -> toSQLId("count"), "syntax" -> toSQLStmt("IGNORE NULLS")), @@ -323,10 +324,10 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { } test("higher order function don't support ignore nulls") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CatalystSqlParser.parsePlan( "SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) IGNORE NULLS"), - expectedErrorClass = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", + expectedErrorCondition = "INVALID_SQL_SYNTAX.FUNCTION_WITH_UNSUPPORTED_SYNTAX", expectedMessageParameters = Map( "prettyName" -> toSQLId("aggregate"), "syntax" -> toSQLStmt("IGNORE NULLS")), @@ -334,11 +335,11 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { "aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) IGNORE NULLS", 7, 68))) } - errorClassTest( + errorConditionTest( name = "nested aggregate functions", testRelation.groupBy($"a")( Max(Count(Literal(1)).toAggregateExpression()).toAggregateExpression()), - errorClass = "NESTED_AGGREGATE_FUNCTION", + condition = "NESTED_AGGREGATE_FUNCTION", messageParameters = Map.empty ) @@ -353,7 +354,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { SpecifiedWindowFrame(RangeFrame, Literal(1), Literal(2)))).as("window")), "Cannot specify window frame for lead function" :: Nil) - errorClassTest( + errorConditionTest( "the offset of nth_value window function is negative or zero", testRelation2.select( WindowExpression( @@ -362,14 +363,14 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, SpecifiedWindowFrame(RowFrame, Literal(0), Literal(0)))).as("window")), - errorClass = "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE", + condition = "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE", messageParameters = Map( "sqlExpr" -> "\"nth_value(b, 0)\"", "exprName" -> "offset", "valueRange" -> "(0, 9223372036854775807]", "currentValue" -> "0L")) - errorClassTest( + errorConditionTest( "the offset of nth_value window function is not int literal", testRelation2.select( WindowExpression( @@ -378,7 +379,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, SpecifiedWindowFrame(RowFrame, Literal(0), Literal(0)))).as("window")), - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = 
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", messageParameters = Map( "sqlExpr" -> "\"nth_value(b, true)\"", "paramIndex" -> "second", @@ -386,7 +387,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { "inputType" -> "\"BOOLEAN\"", "requiredType" -> "\"INT\"")) - errorClassTest( + errorConditionTest( "the buckets of ntile window function is not foldable", testRelation2.select( WindowExpression( @@ -395,7 +396,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, UnspecifiedFrame)).as("window")), - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", messageParameters = Map( "sqlExpr" -> "\"ntile(99.9)\"", "paramIndex" -> "first", @@ -404,7 +405,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { "requiredType" -> "\"INT\"")) - errorClassTest( + errorConditionTest( "the buckets of ntile window function is not int literal", testRelation2.select( WindowExpression( @@ -413,20 +414,20 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, UnspecifiedFrame)).as("window")), - errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", messageParameters = Map( "sqlExpr" -> "\"ntile(b)\"", "inputName" -> "`buckets`", "inputExpr" -> "\"b\"", "inputType" -> "\"INT\"")) - errorClassTest( + errorConditionTest( "unresolved attributes", testRelation.select($"abcd"), "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`abcd`", "proposal" -> "`a`")) - errorClassTest( + errorConditionTest( "unresolved attributes with a generated name", testRelation2.groupBy($"a")(max($"b")) .where(sum($"b") > 0) @@ -434,41 +435,41 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`havingCondition`", "proposal" -> "`max(b)`")) - errorClassTest( + errorConditionTest( "unresolved star expansion in max", testRelation2.groupBy($"a")(sum(UnresolvedStar(None))), - errorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", + condition = "INVALID_USAGE_OF_STAR_OR_REGEX", messageParameters = Map("elem" -> "'*'", "prettyName" -> "expression `sum`") ) - errorClassTest( + errorConditionTest( "sorting by unsupported column types", mapRelation.orderBy($"map".asc), - errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", messageParameters = Map( "sqlExpr" -> "\"map ASC NULLS FIRST\"", "functionName" -> "`sortorder`", "dataType" -> "\"MAP\"")) - errorClassTest( + errorConditionTest( "sorting by attributes are not from grouping expressions", testRelation2.groupBy($"a", $"c")($"a", $"c", count($"a").as("a3")).orderBy($"b".asc), "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`b`", "proposal" -> "`a`, `c`, `a3`")) - errorClassTest( + errorConditionTest( "non-boolean filters", testRelation.where(Literal(1)), - errorClass = "DATATYPE_MISMATCH.FILTER_NOT_BOOLEAN", + condition = "DATATYPE_MISMATCH.FILTER_NOT_BOOLEAN", messageParameters = Map("sqlExpr" -> "\"1\"", "filter" -> "\"1\"", "type" -> "\"INT\"")) - errorClassTest( + errorConditionTest( "non-boolean join conditions", testRelation.join(testRelation, condition = Some(Literal(1))), - errorClass = "JOIN_CONDITION_IS_NOT_BOOLEAN_TYPE", + condition = 
"JOIN_CONDITION_IS_NOT_BOOLEAN_TYPE", messageParameters = Map("joinCondition" -> "\"1\"", "conditionType" -> "\"INT\"")) - errorClassTest( + errorConditionTest( "missing group by", testRelation2.groupBy($"a")($"b"), "MISSING_AGGREGATION", @@ -477,27 +478,27 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { "expressionAnyValue" -> "\"any_value(b)\"") ) - errorClassTest( + errorConditionTest( "ambiguous field", nestedRelation.select($"top.duplicateField"), - errorClass = "AMBIGUOUS_REFERENCE_TO_FIELDS", + condition = "AMBIGUOUS_REFERENCE_TO_FIELDS", messageParameters = Map( "field" -> "`duplicateField`", "count" -> "2"), caseSensitive = false ) - errorClassTest( + errorConditionTest( "ambiguous field due to case insensitivity", nestedRelation.select($"top.differentCase"), - errorClass = "AMBIGUOUS_REFERENCE_TO_FIELDS", + condition = "AMBIGUOUS_REFERENCE_TO_FIELDS", messageParameters = Map( "field" -> "`differentCase`", "count" -> "2"), caseSensitive = false ) - errorClassTest( + errorConditionTest( "missing field", nestedRelation2.select($"top.c"), "FIELD_NOT_FOUND", @@ -509,7 +510,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { val analyzer = getAnalyzer analyzer.checkAnalysis(analyzer.execute(UnresolvedTestPlan())) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "Found the unresolved operator: 'UnresolvedTestPlan")) errorTest( @@ -560,14 +561,14 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { testRelation3.except(testRelation4, isAll = false), "except" :: "compatible column types" :: "map" :: "decimal" :: Nil) - errorClassTest( + errorConditionTest( "SPARK-9955: correct error message for aggregate", // When parse SQL string, we will wrap aggregate expressions with UnresolvedAlias. 
testRelation2.where($"bad_column" > 1).groupBy($"a")(UnresolvedAlias(max($"b"))), "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`bad_column`", "proposal" -> "`a`, `c`, `d`, `b`, `e`")) - errorClassTest( + errorConditionTest( "slide duration greater than window in time window", testRelation2.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "2 second", "0 second").as("window")), @@ -582,7 +583,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "start time greater than slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 minute").as("window")), @@ -597,7 +598,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "start time equal to slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 second").as("window")), @@ -612,7 +613,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "SPARK-21590: absolute value of start time greater than slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-1 minute").as("window")), @@ -627,7 +628,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "SPARK-21590: absolute value of start time equal to slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-1 second").as("window")), @@ -642,7 +643,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "negative window duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "-1 second", "1 second", "0 second").as("window")), @@ -655,7 +656,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "zero window duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "0 second", "1 second", "0 second").as("window")), @@ -668,7 +669,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "negative slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "-1 second", "0 second").as("window")), @@ -681,7 +682,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "zero slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "0 second", "0 second").as("window")), @@ -733,7 +734,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { "The generator is not supported: outside the SELECT clause, found: Sort" :: Nil ) - errorClassTest( + errorConditionTest( "an evaluated limit class must not be string", testRelation.limit(Literal(UTF8String.fromString("abc"), StringType)), "INVALID_LIMIT_LIKE_EXPRESSION.DATA_TYPE", @@ -744,7 +745,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "an evaluated limit class must not be long", testRelation.limit(Literal(10L, LongType)), "INVALID_LIMIT_LIKE_EXPRESSION.DATA_TYPE", @@ -755,7 +756,7 @@ class AnalysisErrorSuite 
extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "an evaluated limit class must not be null", testRelation.limit(Literal(null, IntegerType)), "INVALID_LIMIT_LIKE_EXPRESSION.IS_NULL", @@ -765,7 +766,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "num_rows in limit clause must be equal to or greater than 0", listRelation.limit(-1), "INVALID_LIMIT_LIKE_EXPRESSION.IS_NEGATIVE", @@ -776,7 +777,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "an evaluated offset class must not be string", testRelation.offset(Literal(UTF8String.fromString("abc"), StringType)), "INVALID_LIMIT_LIKE_EXPRESSION.DATA_TYPE", @@ -787,7 +788,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "an evaluated offset class must not be long", testRelation.offset(Literal(10L, LongType)), "INVALID_LIMIT_LIKE_EXPRESSION.DATA_TYPE", @@ -798,7 +799,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "an evaluated offset class must not be null", testRelation.offset(Literal(null, IntegerType)), "INVALID_LIMIT_LIKE_EXPRESSION.IS_NULL", @@ -808,7 +809,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "num_rows in offset clause must be equal to or greater than 0", testRelation.offset(-1), "INVALID_LIMIT_LIKE_EXPRESSION.IS_NEGATIVE", @@ -819,7 +820,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { ) ) - errorClassTest( + errorConditionTest( "the sum of num_rows in limit clause and num_rows in offset clause less than Int.MaxValue", testRelation.offset(Literal(2000000000, IntegerType)).limit(Literal(1000000000, IntegerType)), "SUM_OF_LIMIT_AND_OFFSET_EXCEEDS_MAX_INT", @@ -833,14 +834,14 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { """"explode(array(min(a)))", "explode(array(max(a)))"""" :: Nil ) - errorClassTest( + errorConditionTest( "EXEC IMMEDIATE - nested execute immediate not allowed", CatalystSqlParser.parsePlan("EXECUTE IMMEDIATE 'EXECUTE IMMEDIATE \\\'SELECT 42\\\''"), "NESTED_EXECUTE_IMMEDIATE", Map( "sqlString" -> "EXECUTE IMMEDIATE 'SELECT 42'")) - errorClassTest( + errorConditionTest( "EXEC IMMEDIATE - both positional and named used", CatalystSqlParser.parsePlan("EXECUTE IMMEDIATE 'SELECT 42 where ? 
= :first'" + " USING 1, 2 as first"), @@ -853,9 +854,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { scala.util.Right(UnresolvedAttribute("testVarA")), Seq(UnresolvedAttribute("testVarA"))) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = execImmediatePlan, - expectedErrorClass = "INVALID_VARIABLE_TYPE_FOR_QUERY_EXECUTE_IMMEDIATE", + expectedErrorCondition = "INVALID_VARIABLE_TYPE_FOR_QUERY_EXECUTE_IMMEDIATE", expectedMessageParameters = Map( "varType" -> "\"INT\"" )) @@ -867,9 +868,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { scala.util.Right(UnresolvedAttribute("testVarNull")), Seq(UnresolvedAttribute("testVarNull"))) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = execImmediatePlan, - expectedErrorClass = "NULL_QUERY_STRING_EXECUTE_IMMEDIATE", + expectedErrorCondition = "NULL_QUERY_STRING_EXECUTE_IMMEDIATE", expectedMessageParameters = Map("varName" -> "`testVarNull`")) } @@ -880,9 +881,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { scala.util.Left("SELECT ?"), Seq.empty) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = execImmediatePlan, - expectedErrorClass = "UNSUPPORTED_EXPR_FOR_PARAMETER", + expectedErrorCondition = "UNSUPPORTED_EXPR_FOR_PARAMETER", expectedMessageParameters = Map( "invalidExprSql" -> "\"nanvl(1, 1)\"" )) @@ -894,9 +895,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { scala.util.Left("SELECT :first"), Seq.empty) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = execImmediateSetVariablePlan, - expectedErrorClass = "ALL_PARAMETERS_MUST_BE_NAMED", + expectedErrorCondition = "ALL_PARAMETERS_MUST_BE_NAMED", expectedMessageParameters = Map( "exprs" -> "\"2\", \"3\"" )) @@ -908,9 +909,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { scala.util.Left("SET VAR testVarA = 1"), Seq(UnresolvedAttribute("testVarA"))) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = execImmediateSetVariablePlan, - expectedErrorClass = "INVALID_STATEMENT_FOR_EXECUTE_INTO", + expectedErrorCondition = "INVALID_STATEMENT_FOR_EXECUTE_INTO", expectedMessageParameters = Map( "sqlString" -> "SET VAR TESTVARA = 1" )) @@ -931,9 +932,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { assert(plan.resolved) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = plan, - expectedErrorClass = "MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_APPEAR_IN_OPERATION", + expectedErrorCondition = "MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_APPEAR_IN_OPERATION", expectedMessageParameters = Map( "missingAttributes" -> "\"a\", \"c\"", "input" -> "\"a\"", @@ -949,7 +950,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { exception = intercept[SparkException] { SimpleAnalyzer.checkAnalysis(join) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> """ |Failure when resolving conflicting references in Join: @@ -966,7 +967,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { exception = intercept[SparkException] { SimpleAnalyzer.checkAnalysis(intersect) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> """ |Failure when resolving conflicting references in Intersect All: @@ -983,7 +984,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { exception = intercept[SparkException] { 
SimpleAnalyzer.checkAnalysis(except) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> """ |Failure when resolving conflicting references in Except All: @@ -1003,7 +1004,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { exception = intercept[SparkException] { SimpleAnalyzer.checkAnalysis(asOfJoin) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> """ |Failure when resolving conflicting references in AsOfJoin: @@ -1059,9 +1060,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { AttributeReference("a", IntegerType)(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = plan, - expectedErrorClass = "NESTED_AGGREGATE_FUNCTION", + expectedErrorCondition = "NESTED_AGGREGATE_FUNCTION", expectedMessageParameters = Map.empty ) } @@ -1082,9 +1083,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { joinType = Cross, condition = Some($"b" === $"d")) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = plan2, - expectedErrorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + expectedErrorCondition = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", expectedMessageParameters = Map( "functionName" -> "`=`", "dataType" -> "\"MAP\"", @@ -1145,8 +1146,8 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val plan = Filter($"a" === UnresolvedFunction("max", Seq(b), true), LocalRelation(a, b)) - assertAnalysisErrorClass(plan, - expectedErrorClass = "INVALID_WHERE_CONDITION", + assertAnalysisErrorCondition(plan, + expectedErrorCondition = "INVALID_WHERE_CONDITION", expectedMessageParameters = Map( "condition" -> "\"(a = max(DISTINCT b))\"", "expressionList" -> "max(DISTINCT b)")) @@ -1160,8 +1161,8 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { Project( Alias(Literal(1), "x")() :: Nil, UnresolvedRelation(TableIdentifier("t", Option("nonexist"))))))) - assertAnalysisErrorClass(plan, - expectedErrorClass = "TABLE_OR_VIEW_NOT_FOUND", + assertAnalysisErrorCondition(plan, + expectedErrorCondition = "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`nonexist`.`t`")) } @@ -1170,8 +1171,8 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { (Randn("a".attr), "\"randn(a)\"")).foreach { case (r, expectedArg) => val plan = Project(Seq(r.as("r")), testRelation) - assertAnalysisErrorClass(plan, - expectedErrorClass = "SEED_EXPRESSION_IS_UNFOLDABLE", + assertAnalysisErrorCondition(plan, + expectedErrorCondition = "SEED_EXPRESSION_IS_UNFOLDABLE", expectedMessageParameters = Map( "seedExpr" -> "\"a\"", "exprWithSeed" -> expectedArg), @@ -1184,8 +1185,8 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { Randn("a") -> ("\"randn(a)\"", "\"a\"", "\"STRING\"") ).foreach { case (r, (sqlExpr, inputSql, inputType)) => val plan = Project(Seq(r.as("r")), testRelation) - assertAnalysisErrorClass(plan, - expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + assertAnalysisErrorCondition(plan, + expectedErrorCondition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", expectedMessageParameters = Map( "sqlExpr" -> sqlExpr, "paramIndex" -> "first", @@ -1208,9 +1209,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { t.as("t2"))) ) :: Nil, 
sum($"c2").as("sum") :: Nil, t.as("t1")) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( plan, - expectedErrorClass = + expectedErrorCondition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY", expectedMessageParameters = Map.empty) } @@ -1226,37 +1227,37 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { Filter($"t1.c1" === $"t2.c1", t.as("t2"))) ).as("sub") :: Nil, t.as("t1")) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( plan, - expectedErrorClass = + expectedErrorCondition = "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION", expectedMessageParameters = Map("sqlExpr" -> "\"scalarsubquery(c1)\"")) } - errorClassTest( + errorConditionTest( "SPARK-34920: error code to error message", testRelation2.where($"bad_column" > 1).groupBy($"a")(UnresolvedAlias(max($"b"))), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", messageParameters = Map( "objectName" -> "`bad_column`", "proposal" -> "`a`, `c`, `d`, `b`, `e`")) - errorClassTest( + errorConditionTest( "SPARK-39783: backticks in error message for candidate column with dots", // This selects a column that does not exist, // the error message suggest the existing column with correct backticks testRelation6.select($"`the`.`id`"), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", messageParameters = Map( "objectName" -> "`the`.`id`", "proposal" -> "`the.id`")) - errorClassTest( + errorConditionTest( "SPARK-39783: backticks in error message for candidate struct column", // This selects a column that does not exist, // the error message suggest the existing column with correct backticks nestedRelation2.select($"`top.aField`"), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", messageParameters = Map( "objectName" -> "`top.aField`", "proposal" -> "`top`")) @@ -1272,7 +1273,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { exception = intercept[SparkException] { SimpleAnalyzer.checkAnalysis(plan) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "Hint not found: `some_random_hint_that_does_not_exist`")) // UnresolvedHint be removed by batch `Remove Unresolved Hints` @@ -1291,9 +1292,9 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { "Scalar subquery must return only one column, but got 2" :: Nil) // t2.* cannot be resolved and the error should be the initial analysis exception. 
- assertAnalysisErrorClass( + assertAnalysisErrorCondition( Project(ScalarSubquery(t0.select(star("t2"))).as("sub") :: Nil, t1), - expectedErrorClass = "CANNOT_RESOLVE_STAR_EXPAND", + expectedErrorCondition = "CANNOT_RESOLVE_STAR_EXPAND", expectedMessageParameters = Map("targetString" -> "`t2`", "columns" -> "") ) } @@ -1306,70 +1307,70 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { val t2 = LocalRelation(b, c).as("t2") // SELECT * FROM t1 WHERE a = (SELECT sum(c) FROM t2 WHERE t1.* = t2.b) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( Filter(EqualTo(a, ScalarSubquery(t2.select(sum(c)).where(star("t1") === b))), t1), - expectedErrorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", + expectedErrorCondition = "INVALID_USAGE_OF_STAR_OR_REGEX", expectedMessageParameters = Map("elem" -> "'*'", "prettyName" -> "expression `equalto`") ) // SELECT * FROM t1 JOIN t2 ON (EXISTS (SELECT 1 FROM t2 WHERE t1.* = b)) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( t1.join(t2, condition = Some(Exists(t2.select(1).where(star("t1") === b)))), - expectedErrorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", + expectedErrorCondition = "INVALID_USAGE_OF_STAR_OR_REGEX", expectedMessageParameters = Map("elem" -> "'*'", "prettyName" -> "expression `equalto`") ) } test("SPARK-36488: Regular expression expansion should fail with a meaningful message") { withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "true") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( testRelation.select(Divide(UnresolvedRegex(".?", None, false), "a")), - expectedErrorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", + expectedErrorCondition = "INVALID_USAGE_OF_STAR_OR_REGEX", expectedMessageParameters = Map( "elem" -> "regular expression '.?'", "prettyName" -> "expression `divide`") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( testRelation.select( Divide(UnresolvedRegex(".?", None, false), UnresolvedRegex(".*", None, false))), - expectedErrorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", + expectedErrorCondition = "INVALID_USAGE_OF_STAR_OR_REGEX", expectedMessageParameters = Map( "elem" -> "regular expressions '.?', '.*'", "prettyName" -> "expression `divide`") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( testRelation.select( Divide(UnresolvedRegex(".?", None, false), UnresolvedRegex(".?", None, false))), - expectedErrorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", + expectedErrorCondition = "INVALID_USAGE_OF_STAR_OR_REGEX", expectedMessageParameters = Map( "elem" -> "regular expression '.?'", "prettyName" -> "expression `divide`") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( testRelation.select(Divide(UnresolvedStar(None), "a")), - expectedErrorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", + expectedErrorCondition = "INVALID_USAGE_OF_STAR_OR_REGEX", expectedMessageParameters = Map( "elem" -> "'*'", "prettyName" -> "expression `divide`") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( testRelation.select(Divide(UnresolvedStar(None), UnresolvedStar(None))), - expectedErrorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", + expectedErrorCondition = "INVALID_USAGE_OF_STAR_OR_REGEX", expectedMessageParameters = Map( "elem" -> "'*'", "prettyName" -> "expression `divide`") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( testRelation.select(Divide(UnresolvedStar(None), UnresolvedRegex(".?", None, false))), - expectedErrorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", + expectedErrorCondition = "INVALID_USAGE_OF_STAR_OR_REGEX", 
expectedMessageParameters = Map( "elem" -> "'*' and regular expression '.?'", "prettyName" -> "expression `divide`") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( testRelation.select(Least(Seq(UnresolvedStar(None), UnresolvedRegex(".*", None, false), UnresolvedRegex(".?", None, false)))), - expectedErrorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", + expectedErrorCondition = "INVALID_USAGE_OF_STAR_OR_REGEX", expectedMessageParameters = Map( "elem" -> "'*' and regular expressions '.*', '.?'", "prettyName" -> "expression `least`") @@ -1377,7 +1378,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { } } - errorClassTest( + errorConditionTest( "SPARK-47572: Enforce Window partitionSpec is orderable", testRelation2.select( WindowExpression( @@ -1386,7 +1387,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { CreateMap(Literal("key") :: UnresolvedAttribute("a") :: Nil) :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, UnspecifiedFrame)).as("window")), - errorClass = "EXPRESSION_TYPE_IS_NOT_ORDERABLE", + condition = "EXPRESSION_TYPE_IS_NOT_ORDERABLE", messageParameters = Map( "expr" -> "\"_w0\"", "exprType" -> "\"MAP\"")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisExceptionPositionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisExceptionPositionSuite.scala index be256adbd8929..55f59f7a22574 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisExceptionPositionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisExceptionPositionSuite.scala @@ -48,7 +48,7 @@ class AnalysisExceptionPositionSuite extends AnalysisTest { verifyTableOrViewPosition("REFRESH TABLE unknown", "unknown") verifyTableOrViewPosition("SHOW COLUMNS FROM unknown", "unknown") // Special case where namespace is prepended to the table name. 
- assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsePlan("SHOW COLUMNS FROM unknown IN db"), "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`db`.`unknown`"), @@ -94,7 +94,7 @@ class AnalysisExceptionPositionSuite extends AnalysisTest { private def verifyPosition(sql: String, table: String): Unit = { val startPos = sql.indexOf(table) assert(startPos != -1) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsePlan(sql), "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> s"`$table`"), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 1816c620414c9..e23a753dafe8c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -74,7 +74,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { None, CaseInsensitiveStringMap.empty()).analyze }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "Logical plan should not have output of char/varchar type.*\n"), matchPVals = true) @@ -112,7 +112,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Project(testRelation.output, testRelation)) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( Project(Seq(UnresolvedAttribute("tBl.a")), SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), "UNRESOLVED_COLUMN.WITH_SUGGESTION", @@ -359,7 +359,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { val plan = Project(Alias(In(Literal(null), Seq(Literal(true), Literal(1))), "a")() :: Nil, LocalRelation() ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( plan, "DATATYPE_MISMATCH.DATA_DIFF_TYPES", Map( @@ -555,7 +555,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { assertAnalysisSuccess(rangeWithAliases(3 :: Nil, "a" :: Nil)) assertAnalysisSuccess(rangeWithAliases(1 :: 4 :: Nil, "b" :: Nil)) assertAnalysisSuccess(rangeWithAliases(2 :: 6 :: 2 :: Nil, "c" :: Nil)) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( rangeWithAliases(3 :: Nil, "a" :: "b" :: Nil), "NUM_TABLE_VALUE_ALIASES_MISMATCH", Map("funcName" -> "`range`", "aliasesNum" -> "2", "outColsNum" -> "1")) @@ -569,12 +569,12 @@ class AnalysisSuite extends AnalysisTest with Matchers { ).select(star()) } assertAnalysisSuccess(tableColumnsWithAliases("col1" :: "col2" :: "col3" :: "col4" :: Nil)) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( tableColumnsWithAliases("col1" :: Nil), "ASSIGNMENT_ARITY_MISMATCH", Map("numExpr" -> "1", "numTarget" -> "4") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( tableColumnsWithAliases("col1" :: "col2" :: "col3" :: "col4" :: "col5" :: Nil), "ASSIGNMENT_ARITY_MISMATCH", Map("numExpr" -> "5", "numTarget" -> "4") @@ -591,12 +591,12 @@ class AnalysisSuite extends AnalysisTest with Matchers { ).select(star()) } assertAnalysisSuccess(tableColumnsWithAliases("col1" :: "col2" :: "col3" :: "col4" :: Nil)) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( tableColumnsWithAliases("col1" :: Nil), "ASSIGNMENT_ARITY_MISMATCH", Map("numExpr" -> "1", "numTarget" -> "4") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( tableColumnsWithAliases("col1" :: "col2" :: "col3" :: "col4" :: "col5" :: Nil), "ASSIGNMENT_ARITY_MISMATCH", Map("numExpr" -> "5", 
"numTarget" -> "4") @@ -615,12 +615,12 @@ class AnalysisSuite extends AnalysisTest with Matchers { ).select(star()) } assertAnalysisSuccess(joinRelationWithAliases("col1" :: "col2" :: "col3" :: "col4" :: Nil)) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( joinRelationWithAliases("col1" :: Nil), "ASSIGNMENT_ARITY_MISMATCH", Map("numExpr" -> "1", "numTarget" -> "4") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( joinRelationWithAliases("col1" :: "col2" :: "col3" :: "col4" :: "col5" :: Nil), "ASSIGNMENT_ARITY_MISMATCH", Map("numExpr" -> "5", "numTarget" -> "4") @@ -755,7 +755,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { test("SPARK-34741: Avoid ambiguous reference in MergeIntoTable") { val cond = $"a" > 1 - assertAnalysisErrorClass( + assertAnalysisErrorCondition( MergeIntoTable( testRelation, testRelation, @@ -794,7 +794,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("CTE with non-existing column alias") { - assertAnalysisErrorClass(parsePlan("WITH t(x) AS (SELECT 1) SELECT * FROM t WHERE y = 1"), + assertAnalysisErrorCondition(parsePlan("WITH t(x) AS (SELECT 1) SELECT * FROM t WHERE y = 1"), "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`y`", "proposal" -> "`x`"), Array(ExpectedContext("y", 46, 46)) @@ -802,7 +802,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("CTE with non-matching column alias") { - assertAnalysisErrorClass(parsePlan("WITH t(x, y) AS (SELECT 1) SELECT * FROM t WHERE x = 1"), + assertAnalysisErrorCondition( + parsePlan("WITH t(x, y) AS (SELECT 1) SELECT * FROM t WHERE x = 1"), "ASSIGNMENT_ARITY_MISMATCH", Map("numExpr" -> "2", "numTarget" -> "1"), Array(ExpectedContext("t(x, y) AS (SELECT 1)", 5, 25)) @@ -810,7 +811,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-28251: Insert into non-existing table error message is user friendly") { - assertAnalysisErrorClass(parsePlan("INSERT INTO test VALUES (1)"), + assertAnalysisErrorCondition(parsePlan("INSERT INTO test VALUES (1)"), "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`test`"), Array(ExpectedContext("test", 12, 15))) } @@ -826,9 +827,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { // Bad name assert(!CollectMetrics("", sum :: Nil, testRelation, 0).resolved) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CollectMetrics("", sum :: Nil, testRelation, 0), - expectedErrorClass = "INVALID_OBSERVED_METRICS.MISSING_NAME", + expectedErrorCondition = "INVALID_OBSERVED_METRICS.MISSING_NAME", expectedMessageParameters = Map( "operator" -> "'CollectMetrics , [sum(a#x) AS sum#xL], 0\n+- LocalRelation , [a#x]\n") @@ -853,37 +854,38 @@ class AnalysisSuite extends AnalysisTest with Matchers { ) // Unwrapped attribute - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CollectMetrics("event", a :: Nil, testRelation, 0), - expectedErrorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE", + expectedErrorCondition = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE", expectedMessageParameters = Map("expr" -> "\"a\"") ) // Unwrapped non-deterministic expression - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CollectMetrics("event", Rand(10).as("rnd") :: Nil, testRelation, 0), - expectedErrorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC", + expectedErrorCondition = + "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC", expectedMessageParameters = Map("expr" -> "\"rand(10) AS rnd\"") ) 
// Distinct aggregate - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CollectMetrics( "event", Sum(a).toAggregateExpression(isDistinct = true).as("sum") :: Nil, testRelation, 0), - expectedErrorClass = + expectedErrorCondition = "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_DISTINCT_UNSUPPORTED", expectedMessageParameters = Map("expr" -> "\"sum(DISTINCT a) AS sum\"") ) // Nested aggregate - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CollectMetrics( "event", Sum(Sum(a).toAggregateExpression()).toAggregateExpression().as("sum") :: Nil, testRelation, 0), - expectedErrorClass = "INVALID_OBSERVED_METRICS.NESTED_AGGREGATES_UNSUPPORTED", + expectedErrorCondition = "INVALID_OBSERVED_METRICS.NESTED_AGGREGATES_UNSUPPORTED", expectedMessageParameters = Map("expr" -> "\"sum(sum(a)) AS sum\"") ) @@ -892,9 +894,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { RowNumber(), WindowSpecDefinition(Nil, a.asc :: Nil, SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CollectMetrics("event", windowExpr.as("rn") :: Nil, testRelation, 0), - expectedErrorClass = "INVALID_OBSERVED_METRICS.WINDOW_EXPRESSIONS_UNSUPPORTED", + expectedErrorCondition = "INVALID_OBSERVED_METRICS.WINDOW_EXPRESSIONS_UNSUPPORTED", expectedMessageParameters = Map( "expr" -> """ @@ -915,22 +917,22 @@ class AnalysisSuite extends AnalysisTest with Matchers { CollectMetrics("evt1", count :: Nil, testRelation, 0) :: Nil)) // Same children, structurally different metrics - fail - assertAnalysisErrorClass( + assertAnalysisErrorCondition( Union( CollectMetrics("evt1", count :: Nil, testRelation, 0) :: CollectMetrics("evt1", sum :: Nil, testRelation, 1) :: Nil), - expectedErrorClass = "DUPLICATED_METRICS_NAME", + expectedErrorCondition = "DUPLICATED_METRICS_NAME", expectedMessageParameters = Map("metricName" -> "evt1") ) // Different children, same metrics - fail val b = $"b".string val tblB = LocalRelation(b) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( Union( CollectMetrics("evt1", count :: Nil, testRelation, 0) :: CollectMetrics("evt1", count :: Nil, tblB, 1) :: Nil), - expectedErrorClass = "DUPLICATED_METRICS_NAME", + expectedErrorCondition = "DUPLICATED_METRICS_NAME", expectedMessageParameters = Map("metricName" -> "evt1") ) @@ -939,9 +941,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { val query = Project( b :: ScalarSubquery(subquery, Nil).as("sum") :: Nil, CollectMetrics("evt1", count :: Nil, tblB, 1)) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( query, - expectedErrorClass = "DUPLICATED_METRICS_NAME", + expectedErrorCondition = "DUPLICATED_METRICS_NAME", expectedMessageParameters = Map("metricName" -> "evt1") ) @@ -949,9 +951,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { val sumWithFilter = sum.transform { case a: AggregateExpression => a.copy(filter = Some(true)) }.asInstanceOf[NamedExpression] - assertAnalysisErrorClass( + assertAnalysisErrorCondition( CollectMetrics("evt1", sumWithFilter :: Nil, testRelation, 0), - expectedErrorClass = + expectedErrorCondition = "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_FILTER_UNSUPPORTED", expectedMessageParameters = Map("expr" -> "\"sum(a) FILTER (WHERE true) AS sum\"") ) @@ -1062,9 +1064,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { AttributeReference("c", IntegerType)(), AttributeReference("d", TimestampType)()) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( Union(firstTable, 
secondTable), - expectedErrorClass = "INCOMPATIBLE_COLUMN_TYPE", + expectedErrorCondition = "INCOMPATIBLE_COLUMN_TYPE", expectedMessageParameters = Map( "tableOrdinalNumber" -> "second", "columnOrdinalNumber" -> "second", @@ -1074,9 +1076,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { "dataType1" -> "\"TIMESTAMP\"") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( Union(firstTable, thirdTable), - expectedErrorClass = "INCOMPATIBLE_COLUMN_TYPE", + expectedErrorCondition = "INCOMPATIBLE_COLUMN_TYPE", expectedMessageParameters = Map( "tableOrdinalNumber" -> "second", "columnOrdinalNumber" -> "third", @@ -1086,9 +1088,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { "dataType1" -> "\"TIMESTAMP\"") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( Union(firstTable, fourthTable), - expectedErrorClass = "INCOMPATIBLE_COLUMN_TYPE", + expectedErrorCondition = "INCOMPATIBLE_COLUMN_TYPE", expectedMessageParameters = Map( "tableOrdinalNumber" -> "second", "columnOrdinalNumber" -> "4th", @@ -1098,9 +1100,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { "dataType1" -> "\"TIMESTAMP\"") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( Except(firstTable, secondTable, isAll = false), - expectedErrorClass = "INCOMPATIBLE_COLUMN_TYPE", + expectedErrorCondition = "INCOMPATIBLE_COLUMN_TYPE", expectedMessageParameters = Map( "tableOrdinalNumber" -> "second", "columnOrdinalNumber" -> "second", @@ -1110,9 +1112,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { "dataType1" -> "\"TIMESTAMP\"") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( Intersect(firstTable, secondTable, isAll = false), - expectedErrorClass = "INCOMPATIBLE_COLUMN_TYPE", + expectedErrorCondition = "INCOMPATIBLE_COLUMN_TYPE", expectedMessageParameters = Map( "tableOrdinalNumber" -> "second", "columnOrdinalNumber" -> "second", @@ -1124,21 +1126,21 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-31975: Throw user facing error when use WindowFunction directly") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = testRelation2.select(RowNumber()), - expectedErrorClass = "WINDOW_FUNCTION_WITHOUT_OVER_CLAUSE", + expectedErrorCondition = "WINDOW_FUNCTION_WITHOUT_OVER_CLAUSE", expectedMessageParameters = Map("funcName" -> "\"row_number()\"") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = testRelation2.select(Sum(RowNumber())), - expectedErrorClass = "WINDOW_FUNCTION_WITHOUT_OVER_CLAUSE", + expectedErrorCondition = "WINDOW_FUNCTION_WITHOUT_OVER_CLAUSE", expectedMessageParameters = Map("funcName" -> "\"row_number()\"") ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = testRelation2.select(RowNumber() + 1), - expectedErrorClass = "WINDOW_FUNCTION_WITHOUT_OVER_CLAUSE", + expectedErrorCondition = "WINDOW_FUNCTION_WITHOUT_OVER_CLAUSE", expectedMessageParameters = Map("funcName" -> "\"row_number()\"") ) } @@ -1297,7 +1299,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { | ORDER BY grouping__id > 0 """.stripMargin), false) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsePlan( """ |SELECT grouping__id FROM ( @@ -1328,7 +1330,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { |ORDER BY c.x |""".stripMargin)) - assertAnalysisErrorClass(parsePlan( + assertAnalysisErrorCondition(parsePlan( """ |SELECT c.x |FROM VALUES NAMED_STRUCT('x', 'A', 'y', 1), NAMED_STRUCT('x', 'A', 'y', 2) AS t(c) @@ -1342,7 +1344,7 @@ class 
AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-38118: Func(wrong_type) in the HAVING clause should throw data mismatch error") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = parsePlan( s""" |WITH t as (SELECT true c) @@ -1350,7 +1352,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { |FROM t |GROUP BY t.c |HAVING mean(t.c) > 0d""".stripMargin), - expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + expectedErrorCondition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", expectedMessageParameters = Map( "sqlExpr" -> "\"mean(c)\"", "paramIndex" -> "first", @@ -1361,7 +1363,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { caseSensitive = false ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = parsePlan( s""" |WITH t as (SELECT true c, false d) @@ -1369,7 +1371,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { |FROM t |GROUP BY t.c, t.d |HAVING mean(c) > 0d""".stripMargin), - expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + expectedErrorCondition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", expectedMessageParameters = Map( "sqlExpr" -> "\"mean(c)\"", "paramIndex" -> "first", @@ -1379,7 +1381,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { queryContext = Array(ExpectedContext("mean(c)", 91, 97)), caseSensitive = false) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = parsePlan( s""" |WITH t as (SELECT true c) @@ -1387,7 +1389,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { |FROM t |GROUP BY t.c |HAVING abs(t.c) > 0d""".stripMargin), - expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + expectedErrorCondition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", expectedMessageParameters = Map( "sqlExpr" -> "\"abs(c)\"", "paramIndex" -> "first", @@ -1399,7 +1401,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { caseSensitive = false ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = parsePlan( s""" |WITH t as (SELECT true c, false d) @@ -1407,7 +1409,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { |FROM t |GROUP BY t.c, t.d |HAVING abs(c) > 0d""".stripMargin), - expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + expectedErrorCondition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", expectedMessageParameters = Map( "sqlExpr" -> "\"abs(c)\"", "paramIndex" -> "first", @@ -1421,7 +1423,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-39354: should be [TABLE_OR_VIEW_NOT_FOUND]") { - assertAnalysisErrorClass(parsePlan( + assertAnalysisErrorCondition(parsePlan( s""" |WITH t1 as (SELECT 1 user_id, CAST("2022-06-02" AS DATE) dt) |SELECT * @@ -1531,13 +1533,13 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-41489: type of filter expression should be a bool") { - assertAnalysisErrorClass(parsePlan( + assertAnalysisErrorCondition(parsePlan( s""" |WITH t1 as (SELECT 1 user_id) |SELECT * |FROM t1 |WHERE 'true'""".stripMargin), - expectedErrorClass = "DATATYPE_MISMATCH.FILTER_NOT_BOOLEAN", + expectedErrorCondition = "DATATYPE_MISMATCH.FILTER_NOT_BOOLEAN", expectedMessageParameters = Map( "sqlExpr" -> "\"true\"", "filter" -> "\"true\"", "type" -> "\"STRING\"") , diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index dc95198574fb4..33b9fb488c94f 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -178,9 +178,9 @@ trait AnalysisTest extends PlanTest {
     }
   }
 
-  protected def assertAnalysisErrorClass(
+  protected def assertAnalysisErrorCondition(
       inputPlan: LogicalPlan,
-      expectedErrorClass: String,
+      expectedErrorCondition: String,
       expectedMessageParameters: Map[String, String],
       queryContext: Array[ExpectedContext] = Array.empty,
       caseSensitive: Boolean = true): Unit = {
@@ -191,7 +191,7 @@ trait AnalysisTest extends PlanTest {
       }
       checkError(
         exception = e,
-        errorClass = expectedErrorClass,
+        condition = expectedErrorCondition,
         parameters = expectedMessageParameters,
         queryContext = queryContext
       )
@@ -199,14 +199,13 @@ trait AnalysisTest extends PlanTest {
   }
 
   protected def interceptParseException(parser: String => Any)(
-    sqlCommand: String, messages: String*)(
-    errorClass: Option[String] = None): Unit = {
+    sqlCommand: String, messages: String*)(condition: Option[String] = None): Unit = {
     val e = parseException(parser)(sqlCommand)
     messages.foreach { message =>
       assert(e.message.contains(message))
     }
-    if (errorClass.isDefined) {
-      assert(e.getErrorClass == errorClass.get)
+    if (condition.isDefined) {
+      assert(e.getErrorClass == condition.get)
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala
index c9e37e255ab44..6b034d3dbee09 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala
@@ -41,8 +41,8 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest {
       ignoreIfExists = false)
 
     assert(!plan.resolved)
-    assertAnalysisErrorClass(plan,
-      expectedErrorClass = "UNSUPPORTED_FEATURE.PARTITION_WITH_NESTED_COLUMN_IS_UNSUPPORTED",
+    assertAnalysisErrorCondition(plan,
+      expectedErrorCondition = "UNSUPPORTED_FEATURE.PARTITION_WITH_NESTED_COLUMN_IS_UNSUPPORTED",
       expectedMessageParameters = Map("cols" -> "`does_not_exist`"))
   }
 
@@ -56,8 +56,8 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest {
       ignoreIfExists = false)
 
     assert(!plan.resolved)
-    assertAnalysisErrorClass(plan,
-      expectedErrorClass = "UNSUPPORTED_FEATURE.PARTITION_WITH_NESTED_COLUMN_IS_UNSUPPORTED",
+    assertAnalysisErrorCondition(plan,
+      expectedErrorCondition = "UNSUPPORTED_FEATURE.PARTITION_WITH_NESTED_COLUMN_IS_UNSUPPORTED",
       expectedMessageParameters = Map("cols" -> "`does_not_exist`.`z`"))
   }
 
@@ -71,8 +71,8 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest {
       ignoreIfExists = false)
 
     assert(!plan.resolved)
-    assertAnalysisErrorClass(plan,
-      expectedErrorClass = "UNSUPPORTED_FEATURE.PARTITION_WITH_NESTED_COLUMN_IS_UNSUPPORTED",
+    assertAnalysisErrorCondition(plan,
+      expectedErrorCondition = "UNSUPPORTED_FEATURE.PARTITION_WITH_NESTED_COLUMN_IS_UNSUPPORTED",
       expectedMessageParameters = Map("cols" -> "`point`.`z`"))
   }
 
@@ -86,8 +86,8 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest {
       ignoreIfExists = false)
 
     assert(!plan.resolved)
-    assertAnalysisErrorClass(plan,
-      expectedErrorClass = "UNSUPPORTED_FEATURE.PARTITION_WITH_NESTED_COLUMN_IS_UNSUPPORTED",
+    assertAnalysisErrorCondition(plan,
+      expectedErrorCondition =
"UNSUPPORTED_FEATURE.PARTITION_WITH_NESTED_COLUMN_IS_UNSUPPORTED", expectedMessageParameters = Map("cols" -> "`does_not_exist`, `point`.`z`")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 4367cbbd24a89..95e118a30771c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -53,7 +53,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr: Expression, messageParameters: Map[String, String]): Unit = { checkError( exception = analysisException(expr), - errorClass = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES", parameters = messageParameters) } @@ -61,7 +61,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr: Expression, messageParameters: Map[String, String]): Unit = { checkError( exception = analysisException(expr), - errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", parameters = messageParameters) } @@ -69,7 +69,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr: Expression, messageParameters: Map[String, String]): Unit = { checkError( exception = analysisException(expr), - errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", parameters = messageParameters) } @@ -77,14 +77,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr: Expression, messageParameters: Map[String, String]): Unit = { checkError( exception = analysisException(expr), - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = messageParameters) } private def assertForWrongType(expr: Expression, messageParameters: Map[String, String]): Unit = { checkError( exception = analysisException(expr), - errorClass = "DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE", + condition = "DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE", parameters = messageParameters) } @@ -93,7 +93,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(BitwiseNot($"stringField")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"~stringField\"", "paramIndex" -> ordinalNumber(0), @@ -426,7 +426,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(Sum($"booleanField")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"sum(booleanField)\"", "paramIndex" -> ordinalNumber(0), @@ -437,7 +437,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(Average($"booleanField")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"avg(booleanField)\"", "paramIndex" -> ordinalNumber(0), @@ -469,7 +469,7 @@ class 
ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(coalesce) }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> toSQLId(coalesce.prettyName), "expectedNum" -> "> 0", @@ -481,7 +481,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(murmur3Hash) }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> toSQLId(murmur3Hash.prettyName), "expectedNum" -> "> 0", @@ -493,7 +493,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(xxHash64) }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> toSQLId(xxHash64.prettyName), "expectedNum" -> "> 0", @@ -504,7 +504,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(Explode($"intField")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"explode(intField)\"", "paramIndex" -> ordinalNumber(0), @@ -516,7 +516,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(PosExplode($"intField")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"posexplode(intField)\"", "paramIndex" -> ordinalNumber(0), @@ -529,7 +529,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer test("check types for CreateNamedStruct") { checkError( exception = analysisException(CreateNamedStruct(Seq("a", "b", 2.0))), - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`named_struct`", "expectedNum" -> "2n (n > 0)", @@ -538,21 +538,21 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer ) checkError( exception = analysisException(CreateNamedStruct(Seq(1, "a", "b", 2.0))), - errorClass = "DATATYPE_MISMATCH.CREATE_NAMED_STRUCT_WITHOUT_FOLDABLE_STRING", + condition = "DATATYPE_MISMATCH.CREATE_NAMED_STRUCT_WITHOUT_FOLDABLE_STRING", parameters = Map( "sqlExpr" -> "\"named_struct(1, a, b, 2.0)\"", "inputExprs" -> "[\"1\"]") ) checkError( exception = analysisException(CreateNamedStruct(Seq($"a".string.at(0), "a", "b", 2.0))), - errorClass = "DATATYPE_MISMATCH.CREATE_NAMED_STRUCT_WITHOUT_FOLDABLE_STRING", + condition = "DATATYPE_MISMATCH.CREATE_NAMED_STRUCT_WITHOUT_FOLDABLE_STRING", parameters = Map( "sqlExpr" -> "\"named_struct(boundreference(), a, b, 2.0)\"", "inputExprs" -> "[\"boundreference()\"]") ) checkError( exception = analysisException(CreateNamedStruct(Seq(Literal.create(null, StringType), "a"))), - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL", + condition = "DATATYPE_MISMATCH.UNEXPECTED_NULL", parameters = Map( "sqlExpr" -> "\"named_struct(NULL, a)\"", "exprName" -> "[\"NULL\"]") @@ -562,7 +562,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer test("check types for CreateMap") { checkError( exception = 
analysisException(CreateMap(Seq("a", "b", 2.0))), - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`map`", "expectedNum" -> "2n (n > 0)", @@ -572,7 +572,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer checkError( exception = analysisException(CreateMap(Seq(Literal(1), Literal("a"), Literal(true), Literal("b")))), - errorClass = "DATATYPE_MISMATCH.CREATE_MAP_KEY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.CREATE_MAP_KEY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"map(1, a, true, b)\"", "functionName" -> "`map`", @@ -582,7 +582,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer checkError( exception = analysisException(CreateMap(Seq(Literal("a"), Literal(1), Literal("b"), Literal(true)))), - errorClass = "DATATYPE_MISMATCH.CREATE_MAP_VALUE_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.CREATE_MAP_VALUE_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"map(a, 1, b, true)\"", "functionName" -> "`map`", @@ -599,7 +599,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(Round($"intField", $"intField")) }, - errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( "sqlExpr" -> "\"round(intField, intField)\"", "inputName" -> "`scale`", @@ -610,7 +610,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(Round($"intField", $"booleanField")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"round(intField, booleanField)\"", "paramIndex" -> ordinalNumber(1), @@ -621,7 +621,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(Round($"intField", $"mapField")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"round(intField, mapField)\"", "paramIndex" -> ordinalNumber(1), @@ -632,7 +632,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(Round($"booleanField", $"intField")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"round(booleanField, intField)\"", "paramIndex" -> ordinalNumber(0), @@ -646,7 +646,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(BRound($"intField", $"intField")) }, - errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( "sqlExpr" -> "\"bround(intField, intField)\"", "inputName" -> "`scale`", @@ -656,7 +656,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(BRound($"intField", $"booleanField")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"bround(intField, booleanField)\"", "paramIndex" -> ordinalNumber(1), @@ -667,7 +667,7 @@ class 
ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(BRound($"intField", $"mapField")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"bround(intField, mapField)\"", "paramIndex" -> ordinalNumber(1), @@ -678,7 +678,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[AnalysisException] { assertSuccess(BRound($"booleanField", $"intField")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"bround(booleanField, intField)\"", "paramIndex" -> ordinalNumber(0), @@ -806,7 +806,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer exception = intercept[SparkException] { wsd.checkInputDataTypes() }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> ("Cannot use an UnspecifiedFrame. " + "This should have been converted during analysis.")) ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala index 1fd81349ac720..1ae3e3fa68603 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala @@ -62,7 +62,7 @@ class LookupFunctionsSuite extends PlanTest { } checkError( exception = cause, - errorClass = "UNRESOLVED_ROUTINE", + condition = "UNRESOLVED_ROUTINE", parameters = Map( "routineName" -> "`undefined_fn`", "searchPath" -> "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`db1`]")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala index 6bc0350a5785d..02543c9fba539 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala @@ -98,18 +98,18 @@ class NamedParameterFunctionSuite extends AnalysisTest { } test("DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT") { - val errorClass = + val condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED" checkError( exception = parseRearrangeException( signature, Seq(k1Arg, k2Arg, k3Arg, k4Arg, namedK1Arg), "foo"), - errorClass = errorClass, + condition = condition, parameters = Map("routineName" -> toSQLId("foo"), "parameterName" -> toSQLId("k1")) ) checkError( exception = parseRearrangeException( signature, Seq(k1Arg, k2Arg, k3Arg, k4Arg, k4Arg), "foo"), - errorClass = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", parameters = Map("routineName" -> toSQLId("foo"), "parameterName" -> toSQLId("k4")) ) } @@ -117,7 +117,7 @@ class NamedParameterFunctionSuite extends AnalysisTest { test("REQUIRED_PARAMETER_NOT_FOUND") { checkError( exception = parseRearrangeException(signature, Seq(k1Arg, k2Arg, k3Arg), "foo"), - errorClass = "REQUIRED_PARAMETER_NOT_FOUND", + condition = "REQUIRED_PARAMETER_NOT_FOUND", parameters = 
Map( "routineName" -> toSQLId("foo"), "parameterName" -> toSQLId("k4"), "index" -> "2")) } @@ -126,7 +126,7 @@ class NamedParameterFunctionSuite extends AnalysisTest { checkError( exception = parseRearrangeException(signature, Seq(k1Arg, k2Arg, k3Arg, k4Arg, NamedArgumentExpression("k5", Literal("k5"))), "foo"), - errorClass = "UNRECOGNIZED_PARAMETER_NAME", + condition = "UNRECOGNIZED_PARAMETER_NAME", parameters = Map("routineName" -> toSQLId("foo"), "argumentName" -> toSQLId("k5"), "proposal" -> (toSQLId("k1") + " " + toSQLId("k2") + " " + toSQLId("k3"))) ) @@ -136,7 +136,7 @@ class NamedParameterFunctionSuite extends AnalysisTest { checkError( exception = parseRearrangeException(signature, Seq(k2Arg, k3Arg, k1Arg, k4Arg), "foo"), - errorClass = "UNEXPECTED_POSITIONAL_ARGUMENT", + condition = "UNEXPECTED_POSITIONAL_ARGUMENT", parameters = Map("routineName" -> toSQLId("foo"), "parameterName" -> toSQLId("k3")) ) } @@ -147,7 +147,7 @@ class NamedParameterFunctionSuite extends AnalysisTest { s" All required arguments should come before optional arguments." checkError( exception = parseRearrangeException(illegalSignature, args, "foo"), - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> errorMessage) ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala index 5809d1e04b9cf..6e911324e0759 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala @@ -82,7 +82,7 @@ class ResolveLambdaVariablesSuite extends PlanTest { checkError( exception = intercept[AnalysisException](Analyzer.execute(p)), - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.DUPLICATE_ARG_NAMES", + condition = "INVALID_LAMBDA_FUNCTION_CALL.DUPLICATE_ARG_NAMES", parameters = Map( "args" -> "`x`, `x`", "caseSensitiveConfig" -> "\"spark.sql.caseSensitive\"") @@ -96,7 +96,7 @@ class ResolveLambdaVariablesSuite extends PlanTest { checkError( exception = intercept[AnalysisException](Analyzer.execute(p)), - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map("expectedNumArgs" -> "3", "actualNumArgs" -> "1") ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala index 5c843d62d6d7c..b7afc803410cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -108,14 +108,14 @@ class ResolveNaturalJoinSuite extends AnalysisTest { } test("using unresolved attribute") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( r1.join(r2, UsingJoin(Inner, Seq("d"))), - expectedErrorClass = "UNRESOLVED_USING_COLUMN_FOR_JOIN", + expectedErrorCondition = "UNRESOLVED_USING_COLUMN_FOR_JOIN", expectedMessageParameters = Map( "colName" -> "`d`", "side" -> "left", "suggestion" -> "`a`, `b`")) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( r1.join(r2, UsingJoin(Inner, Seq("b"))), - expectedErrorClass = "UNRESOLVED_USING_COLUMN_FOR_JOIN", + expectedErrorCondition = 
"UNRESOLVED_USING_COLUMN_FOR_JOIN", expectedMessageParameters = Map( "colName" -> "`b`", "side" -> "right", "suggestion" -> "`a`, `c`")) } @@ -126,17 +126,17 @@ class ResolveNaturalJoinSuite extends AnalysisTest { val usingPlan = r1.join(r2, UsingJoin(Inner, Seq("a")), None) checkAnalysis(usingPlan, expected, caseSensitive = true) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( r1.join(r2, UsingJoin(Inner, Seq("A"))), - expectedErrorClass = "UNRESOLVED_USING_COLUMN_FOR_JOIN", + expectedErrorCondition = "UNRESOLVED_USING_COLUMN_FOR_JOIN", expectedMessageParameters = Map( "colName" -> "`A`", "side" -> "left", "suggestion" -> "`a`, `b`")) } test("using join on nested fields") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( r5.join(r6, UsingJoin(Inner, Seq("d.f1"))), - expectedErrorClass = "UNRESOLVED_USING_COLUMN_FOR_JOIN", + expectedErrorCondition = "UNRESOLVED_USING_COLUMN_FOR_JOIN", expectedMessageParameters = Map( "colName" -> "`d`.`f1`", "side" -> "left", "suggestion" -> "`d`")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 4e17f4624f7e0..86718ee434311 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -71,7 +71,7 @@ class ResolveSubquerySuite extends AnalysisTest { test("lateral join with ambiguous join conditions") { val plan = lateralJoin(t1, t0.select($"b"), condition = Some($"b" === 1)) - assertAnalysisErrorClass(plan, + assertAnalysisErrorCondition(plan, "AMBIGUOUS_REFERENCE", Map("name" -> "`b`", "referenceNames" -> "[`b`, `b`]") ) } @@ -123,7 +123,7 @@ class ResolveSubquerySuite extends AnalysisTest { // SELECT * FROM t1, LATERAL (SELECT * FROM t2, LATERAL (SELECT a, b, c)) // TODO: support accessing columns from outer outer query. - assertAnalysisErrorClass( + assertAnalysisErrorCondition( lateralJoin(t1, lateralJoin(t2, t0.select($"a", $"b", $"c"))), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", Map("objectName" -> "`a`") @@ -132,25 +132,25 @@ class ResolveSubquerySuite extends AnalysisTest { test("lateral subquery with unresolvable attributes") { // SELECT * FROM t1, LATERAL (SELECT a, c) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( lateralJoin(t1, t0.select($"a", $"c")), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", Map("objectName" -> "`c`") ) // SELECT * FROM t1, LATERAL (SELECT a, b, c, d FROM t2) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( lateralJoin(t1, t2.select($"a", $"b", $"c", $"d")), "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`d`", "proposal" -> "`b`, `c`") ) // SELECT * FROM t1, LATERAL (SELECT * FROM t2, LATERAL (SELECT t1.a)) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( lateralJoin(t1, lateralJoin(t2, t0.select($"t1.a"))), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", Map("objectName" -> "`t1`.`a`") ) // SELECT * FROM t1, LATERAL (SELECT * FROM t2, LATERAL (SELECT a, b)) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( lateralJoin(t1, lateralJoin(t2, t0.select($"a", $"b"))), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", Map("objectName" -> "`a`") @@ -165,7 +165,7 @@ class ResolveSubquerySuite extends AnalysisTest { LateralJoin(t4, LateralSubquery(Project(Seq(xa, ya), t0), Seq(x, y)), Inner, None) ) // Analyzer will try to resolve struct first before subquery alias. 
- assertAnalysisErrorClass( + assertAnalysisErrorCondition( lateralJoin(t1.as("x"), t4.select($"x.a", $"x.b")), "FIELD_NOT_FOUND", Map("fieldName" -> "`b`", "fields" -> "`a`")) @@ -174,9 +174,9 @@ class ResolveSubquerySuite extends AnalysisTest { test("lateral join with unsupported expressions") { val plan = lateralJoin(t1, t0.select(($"a" + $"b").as("c")), condition = Some(sum($"a") === sum($"c"))) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( plan, - expectedErrorClass = "UNSUPPORTED_EXPR_FOR_OPERATOR", + expectedErrorCondition = "UNSUPPORTED_EXPR_FOR_OPERATOR", expectedMessageParameters = Map("invalidExprSqls" -> "\"sum(a)\", \"sum(c)\"") ) } @@ -206,17 +206,17 @@ class ResolveSubquerySuite extends AnalysisTest { LateralSubquery(Project(Seq(outerA, outerB, b, c), t2.as("t2")), Seq(a, b)), Inner, None) ) // SELECT * FROM t1, LATERAL (SELECT t2.*) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( lateralJoin(t1.as("t1"), t0.select(star("t2"))), - expectedErrorClass = "CANNOT_RESOLVE_STAR_EXPAND", + expectedErrorCondition = "CANNOT_RESOLVE_STAR_EXPAND", expectedMessageParameters = Map("targetString" -> "`t2`", "columns" -> "") ) // Check case sensitivities. // SELECT * FROM t1, LATERAL (SELECT T1.*) val plan = lateralJoin(t1.as("t1"), t0.select(star("T1"))) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( plan, - expectedErrorClass = "CANNOT_RESOLVE_STAR_EXPAND", + expectedErrorCondition = "CANNOT_RESOLVE_STAR_EXPAND", expectedMessageParameters = Map("targetString" -> "`T1`", "columns" -> "") ) assertAnalysisSuccess(plan, caseSensitive = false) @@ -232,9 +232,9 @@ class ResolveSubquerySuite extends AnalysisTest { LateralJoin(t1, LateralSubquery(t0.select(newArray.as(newArray.sql)), Seq(a, b)), Inner, None) ) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( lateralJoin(t1.as("t1"), t0.select(Count(star("t1")))), - expectedErrorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", + expectedErrorCondition = "INVALID_USAGE_OF_STAR_OR_REGEX", expectedMessageParameters = Map("elem" -> "'*'", "prettyName" -> "expression `count`")) } @@ -293,9 +293,9 @@ class ResolveSubquerySuite extends AnalysisTest { :: lv(Symbol("X")) :: Nil)) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = lambdaPlanScanFromTable, - expectedErrorClass = + expectedErrorCondition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.HIGHER_ORDER_FUNCTION", expectedMessageParameters = Map.empty[String, String]) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index ff9c0a1b34f7f..3e9a93dc743df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -740,7 +740,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { testUnaryOperatorInStreamingPlan( "window", Window(Nil, Nil, Nil, _), - errorClass = "NON_TIME_WINDOW_NOT_SUPPORTED_IN_STREAMING") + condition = "NON_TIME_WINDOW_NOT_SUPPORTED_IN_STREAMING") // Output modes with aggregation and non-aggregation plans testOutputMode(Append, shouldSupportAggregation = false, shouldSupportNonAggregation = true) @@ -869,11 +869,11 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { * supports having a batch child plan, forming a batch subplan inside a streaming plan. 
*/ def testUnaryOperatorInStreamingPlan( - operationName: String, - logicalPlanGenerator: LogicalPlan => LogicalPlan, - outputMode: OutputMode = Append, - expectedMsg: String = "", - errorClass: String = ""): Unit = { + operationName: String, + logicalPlanGenerator: LogicalPlan => LogicalPlan, + outputMode: OutputMode = Append, + expectedMsg: String = "", + condition: String = ""): Unit = { val expectedMsgs = if (expectedMsg.isEmpty) Seq(operationName) else Seq(expectedMsg) @@ -882,7 +882,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { wrapInStreaming(logicalPlanGenerator(streamRelation)), outputMode, expectedMsgs, - errorClass) + condition) assertSupportedInStreamingPlan( s"$operationName with batch relation", @@ -1030,11 +1030,11 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { plan: LogicalPlan, outputMode: OutputMode, expectedMsgs: Seq[String], - errorClass: String = ""): Unit = { + condition: String = ""): Unit = { testError( s"streaming plan - $name: not supported", expectedMsgs :+ "streaming" :+ "DataFrame" :+ "Dataset" :+ "not supported", - errorClass) { + condition) { UnsupportedOperationChecker.checkForStreaming(wrapInStreaming(plan), outputMode) } } @@ -1052,7 +1052,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { exception = intercept[AnalysisException] { UnsupportedOperationChecker.checkForStreaming(wrapInStreaming(plan), outputMode) }, - errorClass = "STREAMING_OUTPUT_MODE.UNSUPPORTED_OPERATION", + condition = "STREAMING_OUTPUT_MODE.UNSUPPORTED_OPERATION", sqlState = "42KDE", parameters = Map( "outputMode" -> outputMode.toString.toLowerCase(Locale.ROOT), @@ -1120,7 +1120,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { def testError( testName: String, expectedMsgs: Seq[String], - errorClass: String = "")(testBody: => Unit): Unit = { + condition: String = "")(testBody: => Unit): Unit = { test(testName) { val e = intercept[AnalysisException] { @@ -1132,8 +1132,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { s"actual exception message:\n\t'${e.getMessage}'") } } - if (!errorClass.isEmpty) { - assert(e.getErrorClass == errorClass) + if (!condition.isEmpty) { + assert(e.getErrorClass == condition) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index 2280463c2f244..29c6c63ecfeab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -154,16 +154,16 @@ abstract class V2ANSIWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { } } - override def assertAnalysisErrorClass( + override def assertAnalysisErrorCondition( inputPlan: LogicalPlan, - expectedErrorClass: String, + expectedErrorCondition: String, expectedMessageParameters: Map[String, String], queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.ANSI.toString) { - super.assertAnalysisErrorClass( + super.assertAnalysisErrorCondition( inputPlan, - expectedErrorClass, + expectedErrorCondition, expectedMessageParameters, queryContext, caseSensitive @@ -191,16 +191,16 @@ abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { } } - override def 
assertAnalysisErrorClass( + override def assertAnalysisErrorCondition( inputPlan: LogicalPlan, - expectedErrorClass: String, + expectedErrorCondition: String, expectedMessageParameters: Map[String, String], queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.STRICT.toString) { - super.assertAnalysisErrorClass( + super.assertAnalysisErrorCondition( inputPlan, - expectedErrorClass, + expectedErrorCondition, expectedMessageParameters, queryContext, caseSensitive @@ -212,9 +212,9 @@ abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { val parsedPlan = byName(table, widerTable) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`x`", @@ -235,9 +235,9 @@ abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { val parsedPlan = byName(xRequiredTable, query) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`x`", @@ -254,9 +254,9 @@ abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { val parsedPlan = byPosition(table, widerTable) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`x`", @@ -277,9 +277,9 @@ abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { val parsedPlan = byPosition(xRequiredTable, query) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`x`", @@ -421,9 +421,9 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`") ) } @@ -436,9 +436,9 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`") ) } @@ -499,9 +499,9 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val parsedPlan = byName(requiredTable, query) 
assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`") ) } @@ -514,9 +514,9 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`") ) } @@ -546,9 +546,9 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = parsedPlan, - expectedErrorClass = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", + expectedErrorCondition = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", expectedMessageParameters = Map( "tableName" -> "`table-name`", "tableColumns" -> "`x`, `y`", @@ -561,9 +561,9 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val query = TestRelation(Seq($"b".struct($"y".int, $"x".int, $"z".int), $"a".int)) val writePlan = byName(table, query) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( writePlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`b`", @@ -636,9 +636,9 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val parsedPlan = byPosition(requiredTable, query) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = parsedPlan, - expectedErrorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + expectedErrorCondition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", expectedMessageParameters = Map( "tableName" -> "`table-name`", "tableColumns" -> "`x`, `y`", @@ -654,9 +654,9 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val parsedPlan = byPosition(table, query) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = parsedPlan, - expectedErrorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + expectedErrorCondition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", expectedMessageParameters = Map( "tableName" -> "`table-name`", "tableColumns" -> "`x`, `y`", @@ -693,9 +693,9 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = parsedPlan, - expectedErrorClass = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", + expectedErrorCondition = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", expectedMessageParameters = Map( "tableName" -> "`table-name`", "tableColumns" -> "`x`, `y`", @@ -740,9 +740,9 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { withClue("byName") { val parsedPlan = byName(tableWithStructCol, query) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = 
"INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`col`.`a`") ) } @@ -792,9 +792,9 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val parsedPlan = if (byNameResolution) byName(table, query) else byPosition(table, query) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`b`.`n2`", @@ -821,9 +821,9 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val parsedPlan = if (byNameResolution) byName(table, query) else byPosition(table, query) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`arr`.`element`", @@ -854,9 +854,9 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val parsedPlan = if (byNameResolution) byName(table, query) else byPosition(table, query) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`m`.`key`", @@ -887,9 +887,9 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val parsedPlan = if (byNameResolution) byName(table, query) else byPosition(table, query) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`m`.`value`", @@ -921,17 +921,17 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertNotResolved(parsedPlan) if (byNameResolution) { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`b`.`n2`.`dn3`") ) } else { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.STRUCT_MISSING_FIELDS", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.STRUCT_MISSING_FIELDS", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`b`.`n2`", @@ -964,17 +964,17 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertNotResolved(parsedPlan) if (byNameResolution) { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`arr`.`element`.`y`") ) } else { - 
assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.STRUCT_MISSING_FIELDS", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.STRUCT_MISSING_FIELDS", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`arr`.`element`", @@ -1011,17 +1011,17 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertNotResolved(parsedPlan) if (byNameResolution) { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`m`.`key`.`y`") ) } else { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.STRUCT_MISSING_FIELDS", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.STRUCT_MISSING_FIELDS", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`m`.`key`", @@ -1058,17 +1058,17 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertNotResolved(parsedPlan) if (byNameResolution) { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`m`.`value`.`y`") ) } else { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, - expectedErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.STRUCT_MISSING_FIELDS", + expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.STRUCT_MISSING_FIELDS", expectedMessageParameters = Map( "tableName" -> "`table-name`", "colName" -> "`m`.`value`", @@ -1363,7 +1363,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { LessThanOrEqual(UnresolvedAttribute(Seq("a")), Literal(15.0d))) assertNotResolved(parsedPlan) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan, "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`a`", "proposal" -> "`x`, `y`") @@ -1376,7 +1376,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val parsedPlan2 = OverwriteByExpression.byPosition(tableAcceptAnySchema, query, LessThanOrEqual(UnresolvedAttribute(Seq("a")), Literal(15.0d))) assertNotResolved(parsedPlan2) - assertAnalysisErrorClass( + assertAnalysisErrorCondition( parsedPlan2, "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`a`", "proposal" -> "`x`, `y`") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 7e2bad484b3a6..fbe63f71ae029 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -121,7 +121,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[AnalysisException] { func(name) }, - errorClass = "INVALID_SCHEMA_OR_RELATION_NAME", + condition = "INVALID_SCHEMA_OR_RELATION_NAME", parameters = Map("name" -> toSQLId(name)) ) } @@ -171,7 +171,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[AnalysisException] { 
ResolveDefaultColumns.analyze(columnC, statementType) }, - errorClass = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`c`", @@ -180,7 +180,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[AnalysisException] { ResolveDefaultColumns.analyze(columnD, statementType) }, - errorClass = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`d`", @@ -189,7 +189,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[AnalysisException] { ResolveDefaultColumns.analyze(columnE, statementType) }, - errorClass = "INVALID_DEFAULT_VALUE.DATA_TYPE", + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`e`", @@ -589,7 +589,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { sessionCatalog.alterTableDataSchema( TableIdentifier("t1", Some("default")), StructType(oldTab.dataSchema.drop(1))) }, - errorClass = "_LEGACY_ERROR_TEMP_1071", + condition = "_LEGACY_ERROR_TEMP_1071", parameters = Map("nonExistentColumnNames" -> "[col1]")) } } @@ -817,14 +817,14 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[NoSuchTableException] { catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1")) }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`default`.`view1`") ) checkError( exception = intercept[NoSuchTableException] { catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`default`.`view1`") ) @@ -838,7 +838,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[NoSuchTableException] { catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`default`.`view1`") ) } @@ -1000,7 +1000,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { TableIdentifier("tbl2", Some("db2")), Seq(part1, partWithLessColumns), ignoreIfExists = false) }, - errorClass = "_LEGACY_ERROR_TEMP_1232", + condition = "_LEGACY_ERROR_TEMP_1232", parameters = Map( "specKeys" -> "a", "partitionColumnNames" -> "a, b", @@ -1011,7 +1011,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { TableIdentifier("tbl2", Some("db2")), Seq(part1, partWithMoreColumns), ignoreIfExists = true) }, - errorClass = "_LEGACY_ERROR_TEMP_1232", + condition = "_LEGACY_ERROR_TEMP_1232", parameters = Map( "specKeys" -> "a, b, c", "partitionColumnNames" -> "a, b", @@ -1022,7 +1022,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { TableIdentifier("tbl2", Some("db2")), Seq(partWithUnknownColumns, part1), ignoreIfExists = true) }, - errorClass = "_LEGACY_ERROR_TEMP_1232", + condition = "_LEGACY_ERROR_TEMP_1232", parameters = Map( "specKeys" -> "a, unknown", "partitionColumnNames" -> "a, b", @@ -1033,7 +1033,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { TableIdentifier("tbl2", 
Some("db2")), Seq(partWithEmptyValue, part1), ignoreIfExists = true) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> "The spec ([a=3, b=]) contains an empty partition column value")) } @@ -1126,7 +1126,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { purge = false, retainData = false) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> ("The spec (a, b, c) must be contained within the partition " + s"spec (a, b) defined in table '`$SESSION_CATALOG_NAME`.`db2`.`tbl2`'"))) @@ -1139,7 +1139,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { purge = false, retainData = false) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> ("The spec (a, unknown) must be contained within the partition " + s"spec (a, b) defined in table '`$SESSION_CATALOG_NAME`.`db2`.`tbl2`'"))) @@ -1152,7 +1152,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { purge = false, retainData = false) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> "The spec ([a=3, b=]) contains an empty partition column value")) } @@ -1192,7 +1192,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[AnalysisException] { catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithLessColumns.spec) }, - errorClass = "_LEGACY_ERROR_TEMP_1232", + condition = "_LEGACY_ERROR_TEMP_1232", parameters = Map( "specKeys" -> "a", "partitionColumnNames" -> "a, b", @@ -1201,7 +1201,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[AnalysisException] { catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithMoreColumns.spec) }, - errorClass = "_LEGACY_ERROR_TEMP_1232", + condition = "_LEGACY_ERROR_TEMP_1232", parameters = Map( "specKeys" -> "a, b, c", "partitionColumnNames" -> "a, b", @@ -1210,7 +1210,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[AnalysisException] { catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithUnknownColumns.spec) }, - errorClass = "_LEGACY_ERROR_TEMP_1232", + condition = "_LEGACY_ERROR_TEMP_1232", parameters = Map( "specKeys" -> "a, unknown", "partitionColumnNames" -> "a, b", @@ -1219,7 +1219,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[AnalysisException] { catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithEmptyValue.spec) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> "The spec ([a=3, b=]) contains an empty partition column value")) } @@ -1277,7 +1277,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { TableIdentifier("tbl1", Some("db2")), Seq(part1.spec), Seq(partWithLessColumns.spec)) }, - errorClass = "_LEGACY_ERROR_TEMP_1232", + condition = "_LEGACY_ERROR_TEMP_1232", parameters = Map( "specKeys" -> "a", "partitionColumnNames" -> "a, b", @@ -1288,7 +1288,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { TableIdentifier("tbl1", Some("db2")), Seq(part1.spec), Seq(partWithMoreColumns.spec)) }, - errorClass = "_LEGACY_ERROR_TEMP_1232", + condition = "_LEGACY_ERROR_TEMP_1232", parameters = Map( "specKeys" -> 
"a, b, c", "partitionColumnNames" -> "a, b", @@ -1299,7 +1299,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { TableIdentifier("tbl1", Some("db2")), Seq(part1.spec), Seq(partWithUnknownColumns.spec)) }, - errorClass = "_LEGACY_ERROR_TEMP_1232", + condition = "_LEGACY_ERROR_TEMP_1232", parameters = Map( "specKeys" -> "a, unknown", "partitionColumnNames" -> "a, b", @@ -1310,7 +1310,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { TableIdentifier("tbl1", Some("db2")), Seq(part1.spec), Seq(partWithEmptyValue.spec)) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> "The spec ([a=3, b=]) contains an empty partition column value")) } @@ -1364,7 +1364,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[AnalysisException] { catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithLessColumns)) }, - errorClass = "_LEGACY_ERROR_TEMP_1232", + condition = "_LEGACY_ERROR_TEMP_1232", parameters = Map( "specKeys" -> "a", "partitionColumnNames" -> "a, b", @@ -1373,7 +1373,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[AnalysisException] { catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithMoreColumns)) }, - errorClass = "_LEGACY_ERROR_TEMP_1232", + condition = "_LEGACY_ERROR_TEMP_1232", parameters = Map( "specKeys" -> "a, b, c", "partitionColumnNames" -> "a, b", @@ -1382,7 +1382,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[AnalysisException] { catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithUnknownColumns)) }, - errorClass = "_LEGACY_ERROR_TEMP_1232", + condition = "_LEGACY_ERROR_TEMP_1232", parameters = Map( "specKeys" -> "a, unknown", "partitionColumnNames" -> "a, b", @@ -1391,7 +1391,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[AnalysisException] { catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithEmptyValue)) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> "The spec ([a=3, b=]) contains an empty partition column value")) } @@ -1423,7 +1423,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), Some(partWithMoreColumns.spec)) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> ("The spec (a, b, c) must be contained within the partition spec (a, b) " + s"defined in table '`$SESSION_CATALOG_NAME`.`db2`.`tbl2`'"))) @@ -1432,7 +1432,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), Some(partWithUnknownColumns.spec)) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> ("The spec (a, unknown) must be contained within the partition " + s"spec (a, b) defined in table '`$SESSION_CATALOG_NAME`.`db2`.`tbl2`'"))) @@ -1441,7 +1441,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), Some(partWithEmptyValue.spec)) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> 
"The spec ([a=3, b=]) contains an empty partition column value")) } @@ -1471,7 +1471,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithMoreColumns.spec)) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> ("The spec (a, b, c) must be contained within the partition spec (a, b) " + s"defined in table '`$SESSION_CATALOG_NAME`.`db2`.`tbl2`'"))) @@ -1480,7 +1480,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithUnknownColumns.spec)) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> ("The spec (a, unknown) must be contained within the partition " + s"spec (a, b) defined in table '`$SESSION_CATALOG_NAME`.`db2`.`tbl2`'"))) @@ -1489,7 +1489,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithEmptyValue.spec)) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> "The spec ([a=3, b=]) contains an empty partition column value")) } @@ -1582,7 +1582,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { newFunc("temp1", None), overrideIfExists = false, functionBuilder = Some(tempFunc3)) } checkError(e, - errorClass = "ROUTINE_ALREADY_EXISTS", + condition = "ROUTINE_ALREADY_EXISTS", parameters = Map("routineName" -> "`temp1`", "newRoutineType" -> "routine", "existingRoutineType" -> "routine")) @@ -1601,7 +1601,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { overrideIfExists = false, None) }, - errorClass = "CANNOT_LOAD_FUNCTION_CLASS", + condition = "CANNOT_LOAD_FUNCTION_CLASS", parameters = Map( "className" -> "function_class_cannot_load", "functionName" -> "`temp2`" @@ -1712,14 +1712,14 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[NoSuchFunctionException] { catalog.lookupFunction(FunctionIdentifier("func1"), arguments) }, - errorClass = "ROUTINE_NOT_FOUND", + condition = "ROUTINE_NOT_FOUND", parameters = Map("routineName" -> "`default`.`func1`") ) checkError( exception = intercept[NoSuchTempFunctionException] { catalog.dropTempFunction("func1", ignoreIfNotExists = false) }, - errorClass = "ROUTINE_NOT_FOUND", + condition = "ROUTINE_NOT_FOUND", parameters = Map("routineName" -> "`func1`") ) catalog.dropTempFunction("func1", ignoreIfNotExists = true) @@ -1728,7 +1728,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { exception = intercept[NoSuchTempFunctionException] { catalog.dropTempFunction("func2", ignoreIfNotExists = false) }, - errorClass = "ROUTINE_NOT_FOUND", + condition = "ROUTINE_NOT_FOUND", parameters = Map("routineName" -> "`func2`") ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala index 7d9015e566a8c..e8239c7523948 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala @@ -38,7 +38,7 @@ class CSVExprUtilsSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException]{ 
CSVExprUtils.toDelimiterStr(null) }, - errorClass = "INVALID_DELIMITER_VALUE.NULL_VALUE", + condition = "INVALID_DELIMITER_VALUE.NULL_VALUE", parameters = Map.empty) } @@ -47,7 +47,7 @@ class CSVExprUtilsSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException]{ CSVExprUtils.toChar("ab") }, - errorClass = "INVALID_DELIMITER_VALUE.DELIMITER_LONGER_THAN_EXPECTED", + condition = "INVALID_DELIMITER_VALUE.DELIMITER_LONGER_THAN_EXPECTED", parameters = Map("str" -> "ab")) } @@ -56,7 +56,7 @@ class CSVExprUtilsSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException]{ CSVExprUtils.toChar("""\1""") }, - errorClass = "INVALID_DELIMITER_VALUE.UNSUPPORTED_SPECIAL_CHARACTER", + condition = "INVALID_DELIMITER_VALUE.UNSUPPORTED_SPECIAL_CHARACTER", parameters = Map("str" -> """\1""")) } @@ -65,7 +65,7 @@ class CSVExprUtilsSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException]{ CSVExprUtils.toChar("""\""") }, - errorClass = "INVALID_DELIMITER_VALUE.SINGLE_BACKSLASH", + condition = "INVALID_DELIMITER_VALUE.SINGLE_BACKSLASH", parameters = Map.empty) } @@ -74,7 +74,7 @@ class CSVExprUtilsSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException]{ CSVExprUtils.toChar("") }, - errorClass = "INVALID_DELIMITER_VALUE.EMPTY_STRING", + condition = "INVALID_DELIMITER_VALUE.EMPTY_STRING", parameters = Map.empty) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index cbc98d2f23020..514b529ea8cc0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -308,7 +308,7 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { check(filters = Seq(EqualTo("invalid attr", 1)), expected = None) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`invalid attr`", "fields" -> "`i`")) checkError( @@ -319,7 +319,7 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { filters = Seq(EqualTo("i", 1)), expected = Some(InternalRow.empty)) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`i`", "fields" -> "")) } @@ -374,7 +374,7 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { check(new UnivocityParser(StructType(Seq.empty), optionsWithPattern(false))) }, - errorClass = "INVALID_DATETIME_PATTERN.ILLEGAL_CHARACTER", + condition = "INVALID_DATETIME_PATTERN.ILLEGAL_CHARACTER", parameters = Map( "c" -> "n", "pattern" -> "invalid")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala index e852b474aa18c..b7309923ac206 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala @@ -55,7 +55,7 @@ class EncoderErrorMessageSuite extends SparkFunSuite { checkError( exception = intercept[ SparkUnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]()), - errorClass = 
"ENCODER_NOT_FOUND", + condition = "ENCODER_NOT_FOUND", parameters = Map( "typeName" -> "org.apache.spark.sql.catalyst.encoders.NonEncodable", "docroot" -> SPARK_DOC_ROOT) @@ -64,7 +64,7 @@ class EncoderErrorMessageSuite extends SparkFunSuite { checkError( exception = intercept[ SparkUnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]()), - errorClass = "ENCODER_NOT_FOUND", + condition = "ENCODER_NOT_FOUND", parameters = Map( "typeName" -> "org.apache.spark.sql.catalyst.encoders.NonEncodable", "docroot" -> SPARK_DOC_ROOT) @@ -73,7 +73,7 @@ class EncoderErrorMessageSuite extends SparkFunSuite { checkError( exception = intercept[ SparkUnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]()), - errorClass = "ENCODER_NOT_FOUND", + condition = "ENCODER_NOT_FOUND", parameters = Map( "typeName" -> "org.apache.spark.sql.catalyst.encoders.NonEncodable", "docroot" -> SPARK_DOC_ROOT) @@ -82,7 +82,7 @@ class EncoderErrorMessageSuite extends SparkFunSuite { checkError( exception = intercept[ SparkUnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]()), - errorClass = "ENCODER_NOT_FOUND", + condition = "ENCODER_NOT_FOUND", parameters = Map( "typeName" -> "org.apache.spark.sql.catalyst.encoders.NonEncodable", "docroot" -> SPARK_DOC_ROOT) @@ -91,7 +91,7 @@ class EncoderErrorMessageSuite extends SparkFunSuite { checkError( exception = intercept[ SparkUnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]()), - errorClass = "ENCODER_NOT_FOUND", + condition = "ENCODER_NOT_FOUND", parameters = Map( "typeName" -> "org.apache.spark.sql.catalyst.encoders.NonEncodable", "docroot" -> SPARK_DOC_ROOT) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index e29609c741633..35a27f41da80a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders import scala.reflect.runtime.universe.TypeTag import org.apache.spark.SparkRuntimeException -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, Encoders} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} @@ -71,9 +71,9 @@ class EncoderResolutionSuite extends PlanTest { } test("real type doesn't match encoder schema but they are compatible: tupled encoder") { - val encoder = ExpressionEncoder.tuple( - ExpressionEncoder[StringLongClass](), - ExpressionEncoder[Long]()) + val encoder = encoderFor(Encoders.tuple( + Encoders.product[StringLongClass], + Encoders.scalaLong)) val attrs = Seq($"a".struct($"a".string, $"b".byte), $"b".int) testFromRow(encoder, attrs, InternalRow(InternalRow(str, 1.toByte), 2)) } @@ -90,7 +90,7 @@ class EncoderResolutionSuite extends PlanTest { val attrs = Seq($"arr".array(StringType)) checkError( exception = intercept[AnalysisException](encoder.resolveAndBind(attrs)), - errorClass = "CANNOT_UP_CAST_DATATYPE", + condition = "CANNOT_UP_CAST_DATATYPE", parameters = Map("expression" -> "array element", "sourceType" -> "\"STRING\"", "targetType" -> "\"BIGINT\"", "details" -> ( @@ -125,7 +125,7 @@ class EncoderResolutionSuite extends PlanTest { val 
attrs = Seq($"arr".int) checkError( exception = intercept[AnalysisException](encoder.resolveAndBind(attrs)), - errorClass = "UNSUPPORTED_DESERIALIZER.DATA_TYPE_MISMATCH", + condition = "UNSUPPORTED_DESERIALIZER.DATA_TYPE_MISMATCH", parameters = Map("desiredType" -> "\"ARRAY\"", "dataType" -> "\"INT\"")) } @@ -134,7 +134,7 @@ class EncoderResolutionSuite extends PlanTest { val attrs = Seq($"arr".array(new StructType().add("c", "int"))) checkError( exception = intercept[AnalysisException](encoder.resolveAndBind(attrs)), - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`a`", "fields" -> "`c`")) } @@ -145,7 +145,7 @@ class EncoderResolutionSuite extends PlanTest { val attrs = Seq($"nestedArr".array(new StructType().add("arr", "int"))) checkError( exception = intercept[AnalysisException](encoder.resolveAndBind(attrs)), - errorClass = "UNSUPPORTED_DESERIALIZER.DATA_TYPE_MISMATCH", + condition = "UNSUPPORTED_DESERIALIZER.DATA_TYPE_MISMATCH", parameters = Map("desiredType" -> "\"ARRAY\"", "dataType" -> "\"INT\"")) } @@ -154,7 +154,7 @@ class EncoderResolutionSuite extends PlanTest { .add("arr", ArrayType(new StructType().add("c", "int"))))) checkError( exception = intercept[AnalysisException](encoder.resolveAndBind(attrs)), - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`a`", "fields" -> "`c`")) } } @@ -183,7 +183,7 @@ class EncoderResolutionSuite extends PlanTest { val attrs = Seq($"a".string, $"b".long, $"c".int) checkError( exception = intercept[AnalysisException](encoder.resolveAndBind(attrs)), - errorClass = "UNSUPPORTED_DESERIALIZER.FIELD_NUMBER_MISMATCH", + condition = "UNSUPPORTED_DESERIALIZER.FIELD_NUMBER_MISMATCH", parameters = Map("schema" -> "\"STRUCT\"", "ordinal" -> "2")) } @@ -192,7 +192,7 @@ class EncoderResolutionSuite extends PlanTest { val attrs = Seq($"a".string) checkError( exception = intercept[AnalysisException](encoder.resolveAndBind(attrs)), - errorClass = "UNSUPPORTED_DESERIALIZER.FIELD_NUMBER_MISMATCH", + condition = "UNSUPPORTED_DESERIALIZER.FIELD_NUMBER_MISMATCH", parameters = Map("schema" -> "\"STRUCT\"", "ordinal" -> "2")) } @@ -205,7 +205,7 @@ class EncoderResolutionSuite extends PlanTest { val attrs = Seq($"a".string, $"b".struct($"x".long, $"y".string, $"z".int)) checkError( exception = intercept[AnalysisException](encoder.resolveAndBind(attrs)), - errorClass = "UNSUPPORTED_DESERIALIZER.FIELD_NUMBER_MISMATCH", + condition = "UNSUPPORTED_DESERIALIZER.FIELD_NUMBER_MISMATCH", parameters = Map("schema" -> "\"STRUCT\"", "ordinal" -> "2")) } @@ -214,7 +214,7 @@ class EncoderResolutionSuite extends PlanTest { val attrs = Seq($"a".string, $"b".struct($"x".long)) checkError( exception = intercept[AnalysisException](encoder.resolveAndBind(attrs)), - errorClass = "UNSUPPORTED_DESERIALIZER.FIELD_NUMBER_MISMATCH", + condition = "UNSUPPORTED_DESERIALIZER.FIELD_NUMBER_MISMATCH", parameters = Map("schema" -> "\"STRUCT\"", "ordinal" -> "2")) } @@ -233,7 +233,7 @@ class EncoderResolutionSuite extends PlanTest { .foreach { attr => val attrs = Seq(attr) checkError(exception = intercept[AnalysisException](encoder.resolveAndBind(attrs)), - errorClass = "CANNOT_UP_CAST_DATATYPE", + condition = "CANNOT_UP_CAST_DATATYPE", parameters = Map("expression" -> "a", "sourceType" -> ("\"" + attr.dataType.sql + "\""), "targetType" -> "\"STRING\"", "details" -> ( @@ -250,7 +250,7 @@ class EncoderResolutionSuite extends PlanTest { ExpressionEncoder[StringIntClass]().resolveAndBind(Seq($"a".string, 
$"b".long)) } checkError(exception = e1, - errorClass = "CANNOT_UP_CAST_DATATYPE", + condition = "CANNOT_UP_CAST_DATATYPE", parameters = Map("expression" -> "b", "sourceType" -> ("\"BIGINT\""), "targetType" -> "\"INT\"", "details" -> ( @@ -267,7 +267,7 @@ class EncoderResolutionSuite extends PlanTest { } checkError(exception = e2, - errorClass = "CANNOT_UP_CAST_DATATYPE", + condition = "CANNOT_UP_CAST_DATATYPE", parameters = Map("expression" -> "b.`b`", "sourceType" -> ("\"DECIMAL(38,18)\""), "targetType" -> "\"BIGINT\"", "details" -> ( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 7b8d8be6bbeeb..3b5cbed2cc527 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -321,29 +321,29 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes encodeDecodeTest( 1 -> 10L, "tuple with 2 flat encoders")( - ExpressionEncoder.tuple(ExpressionEncoder[Int](), ExpressionEncoder[Long]())) + encoderFor(Encoders.tuple(Encoders.scalaInt, Encoders.scalaLong))) encodeDecodeTest( (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), "tuple with 2 product encoders")( - ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData](), ExpressionEncoder[(Int, Long)]())) + encoderFor(Encoders.tuple(Encoders.product[PrimitiveData], Encoders.product[(Int, Long)]))) encodeDecodeTest( (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), "tuple with flat encoder and product encoder")( - ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData](), ExpressionEncoder[Int]())) + encoderFor(Encoders.tuple(Encoders.product[PrimitiveData], Encoders.scalaInt))) encodeDecodeTest( (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), "tuple with product encoder and flat encoder")( - ExpressionEncoder.tuple(ExpressionEncoder[Int](), ExpressionEncoder[PrimitiveData]())) + encoderFor(Encoders.tuple(Encoders.scalaInt, Encoders.product[PrimitiveData]))) encodeDecodeTest( (1, (10, 100L)), "nested tuple encoder") { - val intEnc = ExpressionEncoder[Int]() - val longEnc = ExpressionEncoder[Long]() - ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) + val intEnc = Encoders.scalaInt + val longEnc = Encoders.scalaLong + encoderFor(Encoders.tuple(intEnc, Encoders.tuple(intEnc, longEnc))) } // test for value classes @@ -435,7 +435,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes implicitly[ExpressionEncoder[Foo]]) checkError( exception = exception, - errorClass = "ENCODER_NOT_FOUND", + condition = "ENCODER_NOT_FOUND", parameters = Map( "typeName" -> "Any", "docroot" -> SPARK_DOC_ROOT) @@ -468,9 +468,8 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes // test for tupled encoders { - val schema = ExpressionEncoder.tuple( - ExpressionEncoder[Int](), - ExpressionEncoder[(String, Int)]()).schema + val encoder = encoderFor(Encoders.tuple(Encoders.scalaInt, Encoders.product[(String, Int)])) + val schema = encoder.schema assert(schema(0).nullable === false) assert(schema(1).nullable) assert(schema(1).dataType.asInstanceOf[StructType](0).nullable) @@ -496,7 +495,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes assert(e.getCause.isInstanceOf[SparkRuntimeException]) checkError( exception = 
e.getCause.asInstanceOf[SparkRuntimeException], - errorClass = "NULL_MAP_KEY", + condition = "NULL_MAP_KEY", parameters = Map.empty ) } @@ -507,19 +506,19 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes assert(e.getCause.isInstanceOf[SparkRuntimeException]) checkError( exception = e.getCause.asInstanceOf[SparkRuntimeException], - errorClass = "NULL_MAP_KEY", + condition = "NULL_MAP_KEY", parameters = Map.empty ) } test("throw exception for tuples with more than 22 elements") { - val encoders = (0 to 22).map(_ => Encoders.scalaInt.asInstanceOf[ExpressionEncoder[_]]) + val encoders = (0 to 22).map(_ => Encoders.scalaInt) checkError( exception = intercept[SparkUnsupportedOperationException] { - ExpressionEncoder.tuple(encoders) + Encoders.tupleEncoder(encoders: _*) }, - errorClass = "_LEGACY_ERROR_TEMP_2150", + condition = "_LEGACY_ERROR_TEMP_2150", parameters = Map.empty) } @@ -531,11 +530,11 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes val encoder = ExpressionEncoder(schema, lenient = true) val unexpectedSerializer = NaNvl(encoder.objSerializer, encoder.objSerializer) val exception = intercept[org.apache.spark.SparkRuntimeException] { - new ExpressionEncoder[Row](unexpectedSerializer, encoder.objDeserializer, encoder.clsTag) + new ExpressionEncoder[Row](encoder.encoder, unexpectedSerializer, encoder.objDeserializer) } checkError( exception = exception, - errorClass = "UNEXPECTED_SERIALIZER_FOR_CLASS", + condition = "UNEXPECTED_SERIALIZER_FOR_CLASS", parameters = Map( "className" -> Utils.getSimpleName(encoder.clsTag.runtimeClass), "expr" -> toSQLExpr(unexpectedSerializer)) @@ -552,11 +551,14 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes encodeDecodeTest(FooClassWithEnum(1, FooEnum.E1), "case class with Int and scala Enum") encodeDecodeTest(FooEnum.E1, "scala Enum") - test("transforming encoder") { + + private def testTransformingEncoder( + name: String, + provider: () => Codec[Any, Array[Byte]]): Unit = test(name) { val encoder = ExpressionEncoder(TransformingEncoder( classTag[(Long, Long)], BinaryEncoder, - JavaSerializationCodec)) + provider)) .resolveAndBind() assert(encoder.schema == new StructType().add("value", BinaryType)) val toRow = encoder.createSerializer() @@ -564,6 +566,9 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes assert(fromRow(toRow((11, 14))) == (11, 14)) } + testTransformingEncoder("transforming java serialization encoder", JavaSerializationCodec) + testTransformingEncoder("transforming kryo encoder", KryoSerializationCodec) + // Scala / Java big decimals ---------------------------------------------------------- encodeDecodeTest(BigDecimal(("9" * 20) + "." 
+ "9" * 18), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala index 71fa60b0c0345..29c5bf3b8d2db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala @@ -69,7 +69,7 @@ class AttributeResolutionSuite extends SparkFunSuite { exception = intercept[AnalysisException] { attrs.resolve(Seq("a"), resolver) }, - errorClass = "AMBIGUOUS_REFERENCE", + condition = "AMBIGUOUS_REFERENCE", parameters = Map( "name" -> "`a`", "referenceNames" -> "[`ns1`.`ns2`.`t2`.`a`, `ns1`.`t1`.`a`]" @@ -86,7 +86,7 @@ class AttributeResolutionSuite extends SparkFunSuite { exception = intercept[AnalysisException] { attrs.resolve(Seq("ns1", "t", "a"), resolver) }, - errorClass = "AMBIGUOUS_REFERENCE", + condition = "AMBIGUOUS_REFERENCE", parameters = Map( "name" -> "`ns1`.`t`.`a`", "referenceNames" -> "[`ns1`.`t`.`a`, `ns2`.`ns1`.`t`.`a`]" @@ -108,7 +108,7 @@ class AttributeResolutionSuite extends SparkFunSuite { exception = intercept[AnalysisException] { attrs.resolve(Seq("ns1", "t", "a", "cc"), resolver) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`cc`", "fields" -> "`aa`, `bb`")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala index 77fdb33e515fc..995b519bd05d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala @@ -103,7 +103,7 @@ class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelp exception = intercept[AnalysisException] { CallMethodViaReflection(Seq.empty).checkInputDataTypes() }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`reflect`", "expectedNum" -> "> 1", @@ -114,7 +114,7 @@ class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelp exception = intercept[AnalysisException] { CallMethodViaReflection(Seq(Literal(staticClassName))).checkInputDataTypes() }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`reflect`", "expectedNum" -> "> 1", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index 67a68fc92a300..e87b54339821f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -1107,8 +1107,10 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { .foreach { interval => checkErrorInExpression[SparkIllegalArgumentException]( cast(Literal.create(interval), YearMonthIntervalType()), - "_LEGACY_ERROR_TEMP_3213", - Map("interval" -> "year-month", "msg" -> "integer overflow")) + "INVALID_INTERVAL_FORMAT.INTERVAL_PARSING", + Map( + 
"interval" -> "year-month", + "input" -> interval)) } Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Int.MinValue + 1, Int.MinValue) @@ -1176,9 +1178,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { val dataType = YearMonthIntervalType() checkErrorInExpression[SparkIllegalArgumentException]( cast(Literal.create(interval), dataType), - "_LEGACY_ERROR_TEMP_3214", + "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING", Map( - "fallBackNotice" -> "", "typeName" -> "interval year to month", "intervalStr" -> "year-month", "supportedFormat" -> "`[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH`", @@ -1198,9 +1199,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { .foreach { case (interval, dataType) => checkErrorInExpression[SparkIllegalArgumentException]( cast(Literal.create(interval), dataType), - "_LEGACY_ERROR_TEMP_3214", + "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING", Map( - "fallBackNotice" -> "", "typeName" -> dataType.typeName, "intervalStr" -> "year-month", "supportedFormat" -> @@ -1322,10 +1322,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { .foreach { case (interval, dataType) => checkErrorInExpression[SparkIllegalArgumentException]( cast(Literal.create(interval), dataType), - "_LEGACY_ERROR_TEMP_3214", - Map("fallBackNotice" -> (", set spark.sql.legacy.fromDayTimeString.enabled" + - " to true to restore the behavior before Spark 3.0."), - "intervalStr" -> "day-time", + "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + Map("intervalStr" -> "day-time", "typeName" -> dataType.typeName, "input" -> interval, "supportedFormat" -> @@ -1348,10 +1346,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { .foreach { case (interval, dataType) => checkErrorInExpression[SparkIllegalArgumentException]( cast(Literal.create(interval), dataType), - "_LEGACY_ERROR_TEMP_3214", - Map("fallBackNotice" -> (", set spark.sql.legacy.fromDayTimeString.enabled" + - " to true to restore the behavior before Spark 3.0."), - "intervalStr" -> "day-time", + "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + Map("intervalStr" -> "day-time", "typeName" -> dataType.typeName, "input" -> interval, "supportedFormat" -> diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala index 175dd05d5911e..e34b54c7086cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala @@ -61,7 +61,7 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("collate on non existing collation") { checkError( exception = intercept[SparkException] { Collate(Literal("abc"), "UTF8_BS") }, - errorClass = "COLLATION_INVALID_NAME", + condition = "COLLATION_INVALID_NAME", sqlState = "42704", parameters = Map("collationName" -> "UTF8_BS", "proposals" -> "UTF8_LCASE")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index c7e995feb5ed8..55148978fa005 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -95,7 +95,7 @@ class CollectionExpressionsSuite } checkError( exception = exception, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map( "message" -> ("The size function doesn't support the operand type " + toSQLType(StringType)) @@ -266,7 +266,7 @@ class CollectionExpressionsSuite checkErrorInExpression[SparkRuntimeException]( MapConcat(Seq(m0, m1)), - errorClass = "DUPLICATED_MAP_KEY", + condition = "DUPLICATED_MAP_KEY", parameters = Map( "key" -> "a", "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") @@ -426,7 +426,7 @@ class CollectionExpressionsSuite checkErrorInExpression[SparkRuntimeException]( MapFromEntries(ai4), - errorClass = "DUPLICATED_MAP_KEY", + condition = "DUPLICATED_MAP_KEY", parameters = Map( "key" -> "1", "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") @@ -458,7 +458,7 @@ class CollectionExpressionsSuite checkErrorInExpression[SparkRuntimeException]( MapFromEntries(as4), - errorClass = "DUPLICATED_MAP_KEY", + condition = "DUPLICATED_MAP_KEY", parameters = Map( "key" -> "a", "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") @@ -720,7 +720,7 @@ class CollectionExpressionsSuite checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6)) checkErrorInExpression[SparkRuntimeException]( expression = Slice(a0, Literal(1), Literal(-1)), - errorClass = "INVALID_PARAMETER_VALUE.LENGTH", + condition = "INVALID_PARAMETER_VALUE.LENGTH", parameters = Map( "parameter" -> toSQLId("length"), "length" -> (-1).toString, @@ -728,7 +728,7 @@ class CollectionExpressionsSuite )) checkErrorInExpression[SparkRuntimeException]( expression = Slice(a0, Literal(0), Literal(1)), - errorClass = "INVALID_PARAMETER_VALUE.START", + condition = "INVALID_PARAMETER_VALUE.START", parameters = Map( "parameter" -> toSQLId("start"), "functionName" -> toSQLId("slice") @@ -910,7 +910,7 @@ class CollectionExpressionsSuite // SPARK-43393: test Sequence overflow checking checkErrorInExpression[SparkRuntimeException]( new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)), - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", + condition = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", parameters = Map( "numberOfElements" -> (BigInt(Int.MaxValue) - BigInt { Int.MinValue } + 1).toString, "functionName" -> toSQLId("sequence"), @@ -918,7 +918,7 @@ class CollectionExpressionsSuite "parameter" -> toSQLId("count"))) checkErrorInExpression[SparkRuntimeException]( new Sequence(Literal(0L), Literal(Long.MaxValue), Literal(1L)), - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", + condition = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", parameters = Map( "numberOfElements" -> (BigInt(Long.MaxValue) + 1).toString, "functionName" -> toSQLId("sequence"), @@ -926,7 +926,7 @@ class CollectionExpressionsSuite "parameter" -> toSQLId("count"))) checkErrorInExpression[SparkRuntimeException]( new Sequence(Literal(0L), Literal(Long.MinValue), Literal(-1L)), - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", + condition = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", parameters = Map( "numberOfElements" -> ((0 - BigInt(Long.MinValue)) + 1).toString(), "functionName" -> toSQLId("sequence"), @@ -934,7 +934,7 @@ class CollectionExpressionsSuite "parameter" -> toSQLId("count"))) checkErrorInExpression[SparkRuntimeException]( new Sequence(Literal(Long.MinValue), Literal(Long.MaxValue), Literal(1L)), - errorClass = 
"COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", + condition = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", parameters = Map( "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString, "functionName" -> toSQLId("sequence"), @@ -942,7 +942,7 @@ class CollectionExpressionsSuite "parameter" -> toSQLId("count"))) checkErrorInExpression[SparkRuntimeException]( new Sequence(Literal(Long.MaxValue), Literal(Long.MinValue), Literal(-1L)), - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", + condition = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", parameters = Map( "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString, "functionName" -> toSQLId("sequence"), @@ -950,7 +950,7 @@ class CollectionExpressionsSuite "parameter" -> toSQLId("count"))) checkErrorInExpression[SparkRuntimeException]( new Sequence(Literal(Long.MaxValue), Literal(-1L), Literal(-1L)), - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", + condition = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", parameters = Map( "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { -1L } + 1).toString, "functionName" -> toSQLId("sequence"), @@ -2293,6 +2293,14 @@ class CollectionExpressionsSuite evaluateWithMutableProjection(Shuffle(ai0, seed2))) assert(evaluateWithUnsafeProjection(Shuffle(ai0, seed1)) !== evaluateWithUnsafeProjection(Shuffle(ai0, seed2))) + + val seed3 = Literal.create(r.nextInt()) + assert(evaluateWithoutCodegen(new Shuffle(ai0, seed3)) === + evaluateWithoutCodegen(new Shuffle(ai0, seed3))) + assert(evaluateWithMutableProjection(new Shuffle(ai0, seed3)) === + evaluateWithMutableProjection(new Shuffle(ai0, seed3))) + assert(evaluateWithUnsafeProjection(new Shuffle(ai0, seed3)) === + evaluateWithUnsafeProjection(new Shuffle(ai0, seed3))) } test("Array Except") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 497b335289b11..7baad5ea92a00 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -285,7 +285,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkErrorInExpression[SparkRuntimeException]( CreateMap(Seq(Literal(1), Literal(2), Literal(1), Literal(3))), - errorClass = "DUPLICATED_MAP_KEY", + condition = "DUPLICATED_MAP_KEY", parameters = Map( "key" -> "1", "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") @@ -328,7 +328,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[AnalysisException] { map3.checkInputDataTypes() }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`map`", "expectedNum" -> "2n (n > 0)", @@ -430,7 +430,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { MapFromArrays( Literal.create(Seq(1, 1), ArrayType(IntegerType)), Literal.create(Seq(2, 3), ArrayType(IntegerType))), - errorClass = "DUPLICATED_MAP_KEY", + condition = "DUPLICATED_MAP_KEY", parameters = Map( "key" -> "1", "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") @@ -483,7 +483,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[AnalysisException] { namedStruct1.checkInputDataTypes() }, - errorClass = 
"WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`named_struct`", "expectedNum" -> "2n (n > 0)", @@ -556,7 +556,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkErrorInExpression[SparkRuntimeException]( new StringToMap(Literal("a:1,b:2,a:3")), - errorClass = "DUPLICATED_MAP_KEY", + condition = "DUPLICATED_MAP_KEY", parameters = Map( "key" -> "a", "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 7ffb321217024..f4c71a1056939 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ @@ -277,4 +278,27 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper assert(!caseWhenObj1.semanticEquals(caseWhenObj2)) assert(!caseWhenObj2.semanticEquals(caseWhenObj1)) } + + test("SPARK-49396 accurate nullability check") { + val trueBranch = (TrueLiteral, Literal(5)) + val normalBranch = (NonFoldableLiteral(true), Literal(10)) + + val nullLiteral = Literal.create(null, BooleanType) + val noElseValue = CaseWhen(normalBranch :: trueBranch :: Nil, None) + assert(!noElseValue.nullable) + val withElseValue = CaseWhen(normalBranch :: trueBranch :: Nil, Some(Literal(1))) + assert(!withElseValue.nullable) + val withNullableElseValue = CaseWhen(normalBranch :: trueBranch :: Nil, Some(nullLiteral)) + assert(!withNullableElseValue.nullable) + val firstTrueNonNullableSecondTrueNullable = CaseWhen(trueBranch :: + (TrueLiteral, nullLiteral) :: Nil, None) + assert(!firstTrueNonNullableSecondTrueNullable.nullable) + val firstTrueNullableSecondTrueNonNullable = CaseWhen((TrueLiteral, nullLiteral) :: + trueBranch :: Nil, None) + assert(firstTrueNullableSecondTrueNonNullable.nullable) + val hasNullInNotTrueBranch = CaseWhen(trueBranch :: (FalseLiteral, nullLiteral) :: Nil, None) + assert(!hasNullInNotTrueBranch.nullable) + val noTrueBranch = CaseWhen(normalBranch :: Nil, Literal(1)) + assert(!noTrueBranch.nullable) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index f1c04c7e33821..21ae35146282b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -2033,12 +2033,12 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkErrorInExpression[SparkArithmeticException](TimestampAdd("DAY", Literal(106751992), Literal(0L, TimestampType)), - errorClass = "DATETIME_OVERFLOW", + condition = "DATETIME_OVERFLOW", parameters = 
Map("operation" -> "add 106751992 DAY to TIMESTAMP '1970-01-01 00:00:00'")) checkErrorInExpression[SparkArithmeticException](TimestampAdd("QUARTER", Literal(1431655764), Literal(0L, TimestampType)), - errorClass = "DATETIME_OVERFLOW", + condition = "DATETIME_OVERFLOW", parameters = Map("operation" -> "add 1431655764 QUARTER to TIMESTAMP '1970-01-01 00:00:00'")) } @@ -2103,11 +2103,11 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("datetime function CurrentDate and localtimestamp are Unevaluable") { checkError(exception = intercept[SparkException] { CurrentDate(UTC_OPT).eval(EmptyRow) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "Cannot evaluate expression: current_date(Some(UTC))")) checkError(exception = intercept[SparkException] { LocalTimestamp(UTC_OPT).eval(EmptyRow) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "Cannot evaluate expression: localtimestamp(Some(UTC))")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a063e53486ad8..02c7ed727a530 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -154,22 +154,22 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB protected def checkErrorInExpression[T <: SparkThrowable : ClassTag]( expression: => Expression, - errorClass: String, + condition: String, parameters: Map[String, String] = Map.empty): Unit = { - checkErrorInExpression[T](expression, InternalRow.empty, errorClass, parameters) + checkErrorInExpression[T](expression, InternalRow.empty, condition, parameters) } protected def checkErrorInExpression[T <: SparkThrowable : ClassTag]( expression: => Expression, inputRow: InternalRow, - errorClass: String): Unit = { - checkErrorInExpression[T](expression, inputRow, errorClass, Map.empty[String, String]) + condition: String): Unit = { + checkErrorInExpression[T](expression, inputRow, condition, Map.empty[String, String]) } protected def checkErrorInExpression[T <: SparkThrowable : ClassTag]( expression: => Expression, inputRow: InternalRow, - errorClass: String, + condition: String, parameters: Map[String, String]): Unit = { def checkException(eval: => Unit, testMode: String): Unit = { @@ -179,7 +179,7 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { checkError( exception = intercept[T](eval), - errorClass = errorClass, + condition = condition, parameters = parameters ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala index dd512b0d83e5c..b6a3d61cb13a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala @@ -82,7 +82,7 @@ class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[AnalysisException] { Stack(Seq(Literal(1))).checkInputDataTypes() }, - 
errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`stack`", "expectedNum" -> "> 1", @@ -93,7 +93,7 @@ class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[AnalysisException] { Stack(Seq(Literal(1.0))).checkInputDataTypes() }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`stack`", "expectedNum" -> "> 1", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index c06705606567d..cc36cd73d6d77 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -471,7 +471,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkErrorInExpression[SparkRuntimeException]( transformKeys(ai0, modKey), - errorClass = "DUPLICATED_MAP_KEY", + condition = "DUPLICATED_MAP_KEY", parameters = Map( "key" -> "1", "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") @@ -858,7 +858,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT.key -> "false") { checkErrorInExpression[SparkException]( expression = arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator), - errorClass = "COMPARATOR_RETURNS_NULL", + condition = "COMPARATOR_RETURNS_NULL", parameters = Map("firstValue" -> "1", "secondValue" -> "1") ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index ff5ffe4e869a0..7caf23490a0ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -351,7 +351,7 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Seq( (Period.ofMonths(2), Int.MaxValue) -> "overflow", - (Period.ofMonths(Int.MinValue), 10d) -> "not in range", + (Period.ofMonths(Int.MinValue), 10d) -> "out of range", (Period.ofMonths(-100), Float.NaN) -> "input is infinite or NaN", (Period.ofMonths(200), Double.PositiveInfinity) -> "input is infinite or NaN", (Period.ofMonths(-200), Float.NegativeInfinity) -> "input is infinite or NaN" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index a23e7f44a48d1..adb39fcd568c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -448,7 +448,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with }.getCause checkError( exception = exception.asInstanceOf[SparkException], - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map("badRecord" -> 
"[null]", "failFastMode" -> "FAILFAST") ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index d42e0b7d681db..b351d69d3a0bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -97,7 +97,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[SparkException] { Literal.default(errType) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> s"No default value for type: ${toSQLType(errType)}.") ) }) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index 1f37886f44258..40e6fe1a90a63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -71,6 +71,13 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { evaluateWithMutableProjection(Uuid(seed2))) assert(evaluateWithUnsafeProjection(Uuid(seed1)) !== evaluateWithUnsafeProjection(Uuid(seed2))) + + val seed3 = Literal.create(r.nextInt()) + assert(evaluateWithoutCodegen(new Uuid(seed3)) === evaluateWithoutCodegen(new Uuid(seed3))) + assert(evaluateWithMutableProjection(new Uuid(seed3)) === + evaluateWithMutableProjection(new Uuid(seed3))) + assert(evaluateWithUnsafeProjection(new Uuid(seed3)) === + evaluateWithUnsafeProjection(new Uuid(seed3))) } test("PrintToStderr") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index ace017b1cddc3..c74a9e35833d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -57,7 +57,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[SparkRuntimeException] { evaluateWithoutCodegen(AssertNotNull(Literal(null))) }, - errorClass = "NOT_NULL_ASSERT_VIOLATION", + condition = "NOT_NULL_ASSERT_VIOLATION", sqlState = "42000", parameters = Map("walkedTypePath" -> "\n\n")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index af38fc1f12f7f..762a4e9166d51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -82,7 +82,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[SparkException] { Invoke(inputObject, "zeroArgNotExistMethod", IntegerType).eval(inputRow) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> ("Couldn't find method zeroArgNotExistMethod with arguments " + "() 
on class java.lang.Object.") @@ -98,7 +98,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Seq(Literal.fromObject(UTF8String.fromString("dummyInputString"))), Seq(StringType)).eval(inputRow) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> ("Couldn't find method oneArgNotExistMethod with arguments " + "(class org.apache.spark.unsafe.types.UTF8String) on class java.lang.Object.") @@ -417,7 +417,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[SparkRuntimeException] { testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType) }, - errorClass = "CLASS_UNSUPPORTED_BY_MAP_OBJECTS", + condition = "CLASS_UNSUPPORTED_BY_MAP_OBJECTS", parameters = Map("cls" -> "scala.collection.Map")) } } @@ -588,7 +588,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { DoubleType, DoubleType), inputRow = InternalRow.fromSeq(Seq(Row(1))), - errorClass = "INVALID_EXTERNAL_TYPE", + condition = "INVALID_EXTERNAL_TYPE", parameters = Map[String, String]( "externalType" -> "java.lang.Integer", "type" -> "\"DOUBLE\"", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 2aa53f581555f..2d58d9d3136aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types.{IntegerType, LongType} class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -41,4 +42,27 @@ class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { assert(Rand(Literal(1L), false).sql === "rand(1L)") assert(Randn(Literal(1L), false).sql === "randn(1L)") } + + test("SPARK-49505: Test the RANDSTR and UNIFORM SQL functions without codegen") { + // Note that we use a seed of zero in these tests to keep the results deterministic. 
+ def testRandStr(first: Any, result: Any): Unit = { + checkEvaluationWithoutCodegen( + RandStr(Literal(first), Literal(0)), CatalystTypeConverters.convertToCatalyst(result)) + } + testRandStr(1, "c") + testRandStr(5, "ceV0P") + testRandStr(10, "ceV0PXaR2I") + testRandStr(10L, "ceV0PXaR2I") + + def testUniform(first: Any, second: Any, result: Any): Unit = { + checkEvaluationWithoutCodegen( + Uniform(Literal(first), Literal(second), Literal(0)).replacement, + CatalystTypeConverters.convertToCatalyst(result)) + } + testUniform(0, 1, 0) + testUniform(0, 10, 7) + testUniform(0L, 10L, 7L) + testUniform(10.0F, 20.0F, 17.604954F) + testUniform(10L, 20.0F, 17.604954F) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 690db18bbfa69..12aeb7d6685bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -158,13 +158,13 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[AnalysisException] { evaluateWithoutCodegen("""a""" like """\a""") }, - errorClass = "INVALID_FORMAT.ESC_IN_THE_MIDDLE", + condition = "INVALID_FORMAT.ESC_IN_THE_MIDDLE", parameters = Map("format" -> """'\\a'""", "char" -> "'a'")) checkError( exception = intercept[AnalysisException] { evaluateWithoutCodegen("""a""" like """a\""") }, - errorClass = "INVALID_FORMAT.ESC_AT_THE_END", + condition = "INVALID_FORMAT.ESC_AT_THE_END", parameters = Map("format" -> """'a\\'""")) // case @@ -238,7 +238,7 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[AnalysisException] { evaluateWithoutCodegen("""a""" like(s"""${escapeChar}a""", escapeChar)) }, - errorClass = "INVALID_FORMAT.ESC_IN_THE_MIDDLE", + condition = "INVALID_FORMAT.ESC_IN_THE_MIDDLE", parameters = Map("format" -> s"'${escapeChar}a'", "char" -> "'a'")) // case @@ -283,7 +283,7 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[SparkRuntimeException] { evaluateWithoutCodegen("abbbbc" rlike "**") }, - errorClass = "INVALID_PARAMETER_VALUE.PATTERN", + condition = "INVALID_PARAMETER_VALUE.PATTERN", parameters = Map( "parameter" -> toSQLId("regexp"), "functionName" -> toSQLId("rlike"), @@ -294,7 +294,7 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val regex = $"a".string.at(0) evaluateWithoutCodegen("abbbbc" rlike regex, create_row("**")) }, - errorClass = "INVALID_PARAMETER_VALUE.PATTERN", + condition = "INVALID_PARAMETER_VALUE.PATTERN", parameters = Map( "parameter" -> toSQLId("regexp"), "functionName" -> toSQLId("rlike"), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index b39820f0d317d..29b878230472d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -155,7 +155,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[AnalysisException] { Elt(Seq.empty).checkInputDataTypes() }, - 
errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`elt`", "expectedNum" -> "> 1", @@ -166,7 +166,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[AnalysisException] { Elt(Seq(Literal(1))).checkInputDataTypes() }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`elt`", "expectedNum" -> "> 1", @@ -1505,7 +1505,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkErrorInExpression[SparkIllegalArgumentException]( toNumberExpr, - errorClass = "INVALID_FORMAT.MISMATCH_INPUT", + condition = "INVALID_FORMAT.MISMATCH_INPUT", parameters = Map( "inputType" -> "\"STRING\"", "input" -> str, @@ -1910,7 +1910,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[AnalysisException] { ParseUrl(Seq(Literal("1"))).checkInputDataTypes() }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`parse_url`", "expectedNum" -> "[2, 3]", @@ -1922,7 +1922,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ParseUrl(Seq(Literal("1"), Literal("2"), Literal("3"), Literal("4"))).checkInputDataTypes() }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`parse_url`", "expectedNum" -> "[2, 3]", @@ -1987,7 +1987,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Test escaping of arguments GenerateUnsafeProjection.generate( - Sentences(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil) + Sentences(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")).replacement :: Nil) } test("SPARK-33386: elt ArrayIndexOutOfBoundsException") { @@ -2037,7 +2037,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[AnalysisException] { expr1.checkInputDataTypes() }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`elt`", "expectedNum" -> "> 1", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala index f138d9642d1e1..446514de91d69 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala @@ -47,7 +47,7 @@ class TryCastSuite extends CastWithAnsiOnSuite { override def checkErrorInExpression[T <: SparkThrowable : ClassTag]( expression: => Expression, inputRow: InternalRow, - errorClass: String, + condition: String, parameters: Map[String, String]): Unit = { checkEvaluation(expression, null, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/XmlExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/XmlExpressionsSuite.scala index 9a10985153044..66baf6b1430fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/XmlExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/XmlExpressionsSuite.scala 
@@ -64,7 +64,7 @@ class XmlExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P }.getCause checkError( exception = exception.asInstanceOf[SparkException], - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map("badRecord" -> "[null]", "failFastMode" -> "FAILFAST") ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index 389c757eefb63..f1d9fd6a36584 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -410,7 +410,7 @@ class PercentileSuite extends SparkFunSuite { agg.update(buffer, InternalRow(1, -5)) agg.eval(buffer) }, - errorClass = "_LEGACY_ERROR_TEMP_2013", + condition = "_LEGACY_ERROR_TEMP_2013", parameters = Map("frequencyExpression" -> "CAST(boundreference() AS INT)")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala index 79f03f23eb245..ca2eaf7be0c21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala @@ -27,7 +27,7 @@ class BufferHolderSuite extends SparkFunSuite { exception = intercept[SparkUnsupportedOperationException] { new BufferHolder(new UnsafeRow(Int.MaxValue / 8)) }, - errorClass = "_LEGACY_ERROR_TEMP_3130", + condition = "_LEGACY_ERROR_TEMP_3130", parameters = Map("numFields" -> "268435455")) val holder = new BufferHolder(new UnsafeRow(1000)) @@ -38,7 +38,7 @@ class BufferHolderSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { holder.grow(Integer.MAX_VALUE) }, - errorClass = "_LEGACY_ERROR_TEMP_3199", + condition = "_LEGACY_ERROR_TEMP_3199", parameters = Map("neededSize" -> "2147483647", "arrayMax" -> "2147483632") ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index d51647ee96df9..4f81ef49e5736 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -121,7 +121,7 @@ class CodeBlockSuite extends SparkFunSuite { exception = intercept[SparkException] { code"$obj" }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> s"Can not interpolate ${obj.getClass.getName} into code block.") ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriterSuite.scala index a968b1fe53506..8a8f0afeb1224 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriterSuite.scala @@ -30,7 +30,7 
@@ class UnsafeArrayWriterSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { arrayWriter.initialize(numElements) }, - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE", + condition = "COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE", parameters = Map( "numberOfElements" -> (numElements * elementSize).toString, "maxRoundedArrayLength" -> Int.MaxValue.toString diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtilsSuite.scala index 2b8c64a1af679..f599fead45015 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtilsSuite.scala @@ -116,7 +116,7 @@ class VariantExpressionEvalUtilsSuite extends SparkFunSuite { } test("parseJson negative") { - def checkException(json: String, errorClass: String, parameters: Map[String, String]): Unit = { + def checkException(json: String, condition: String, parameters: Map[String, String]): Unit = { val try_parse_json_output = VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json), allowDuplicateKeys = false, failOnError = false) checkError( @@ -124,7 +124,7 @@ class VariantExpressionEvalUtilsSuite extends SparkFunSuite { VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json), allowDuplicateKeys = false) }, - errorClass = errorClass, + condition = condition, parameters = parameters ) assert(try_parse_json_output === null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala index fb0bf63c01123..59392548a6ae2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala @@ -45,12 +45,14 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } test("to_json malformed") { - def check(value: Array[Byte], metadata: Array[Byte], - errorClass: String = "MALFORMED_VARIANT"): Unit = { + def check( + value: Array[Byte], + metadata: Array[Byte], + condition: String = "MALFORMED_VARIANT"): Unit = { checkErrorInExpression[SparkRuntimeException]( ResolveTimeZone.resolveTimeZones( StructsToJson(Map.empty, Literal(new VariantVal(value, metadata)))), - errorClass + condition ) } @@ -949,12 +951,24 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { ) } - test("cast to variant") { - def check[T : TypeTag](input: T, expectedJson: String): Unit = { - val cast = Cast(Literal.create(input), VariantType, evalMode = EvalMode.ANSI) - checkEvaluation(StructsToJson(Map.empty, cast), expectedJson) + test("cast to variant/to_variant_object") { + def check[T : TypeTag](input: T, expectedJson: String, + toVariantObject: Boolean = false): Unit = { + val expr = + if (toVariantObject) ToVariantObject(Literal.create(input)) + else Cast(Literal.create(input), VariantType, evalMode = EvalMode.ANSI) + checkEvaluation(StructsToJson(Map.empty, expr), expectedJson) } + def checkFailure[T: TypeTag](input: T, toVariantObject: Boolean = false): Unit = { + val expr = + if (toVariantObject) 
ToVariantObject(Literal.create(input)) + else Cast(Literal.create(input), VariantType, evalMode = EvalMode.ANSI) + val resolvedExpr = ResolveTimeZone.resolveTimeZones(expr) + assert(!resolvedExpr.resolved) + } + + // cast to variant - success cases check(null.asInstanceOf[String], null) // The following tests cover all allowed scalar types. for (input <- Seq[Any](false, true, 0.toByte, 1.toShort, 2, 3L, 4.0F, 5.0D)) { @@ -1023,17 +1037,52 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } check(Array(null, "a", "b", "c"), """[null,"a","b","c"]""") - check(Map("z" -> 1, "y" -> 2, "x" -> 3), """{"x":3,"y":2,"z":1}""") check(Array(parseJson("""{"a": 1,"b": [1, 2, 3]}"""), parseJson("""{"c": true,"d": {"e": "str"}}""")), """[{"a":1,"b":[1,2,3]},{"c":true,"d":{"e":"str"}}]""") - val struct = Literal.create( + + // cast to variant - failure cases - struct and map types + val mp = Map("z" -> 1, "y" -> 2, "x" -> 3) + val arrayMp = Array(Map("z" -> 1, "y" -> 2, "x" -> 3)) + val arrayArrayMp = Array(Array(Map("z" -> 1, "y" -> 2, "x" -> 3))) + checkFailure(mp) + checkFailure(arrayMp) + checkFailure(arrayArrayMp) + val struct = Literal.create(create_row(1), + StructType(Array(StructField("a", IntegerType)))) + checkFailure(struct) + val arrayStruct = Literal.create( + Array(create_row(1)), + ArrayType(StructType(Array(StructField("a", IntegerType))))) + checkFailure(arrayStruct) + + // to_variant_object - success cases - nested types + check(Array(1, 2, 3), "[1,2,3]", toVariantObject = true) + check(mp, """{"x":3,"y":2,"z":1}""", toVariantObject = true) + check(arrayMp, """[{"x":3,"y":2,"z":1}]""", toVariantObject = true) + check(arrayArrayMp, """[[{"x":3,"y":2,"z":1}]]""", toVariantObject = true) + check(struct, """{"a":1}""", toVariantObject = true) + check(arrayStruct, """[{"a":1}]""", toVariantObject = true) + val complexStruct = Literal.create( Row( Seq("123", "true", "f"), Map("a" -> "123", "b" -> "true", "c" -> "f"), + Map("a" -> Row(132)), Row(0)), - StructType.fromDDL("c ARRAY<STRING>,b MAP<STRING, STRING>,a STRUCT<i INT>")) - check(struct, """{"a":{"i":0},"b":{"a":"123","b":"true","c":"f"},"c":["123","true","f"]}""") + StructType.fromDDL("c ARRAY<STRING>,b MAP<STRING, STRING>,d MAP<STRING, STRUCT<i INT>>," + + "a STRUCT<i INT>")) + check(complexStruct, + """{"a":{"i":0},"b":{"a":"123","b":"true","c":"f"},"c":["123","true","f"],""" + + """"d":{"a":{"i":132}}}""", + toVariantObject = true) + check(ymArrLit, """["INTERVAL '0' MONTH","INTERVAL""" + + """ '2147483647' MONTH","INTERVAL '-2147483647' MONTH"]""", toVariantObject = true) + + // to_variant_object - failure cases - non-nested types or map with non-string key + checkFailure(1, toVariantObject = true) + checkFailure(true, toVariantObject = true) + checkFailure(Literal.create(Literal.create(Period.ofMonths(0))), toVariantObject = true) + checkFailure(Map(1 -> 1), toVariantObject = true) } test("schema_of_variant - unknown type") { @@ -1092,7 +1141,7 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { val results = mutable.HashMap.empty[(Literal, Literal), String] for (i <- inputs) { - val inputType = if (i.value == null) "VOID" else i.dataType.sql + val inputType = if (i.value == null) "VOID" else SchemaOfVariant.printSchema(i.dataType) results.put((nul, i), inputType) results.put((i, i), inputType) } @@ -1106,14 +1155,24 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { results.put((timestamp, timestampNtz), "TIMESTAMP") results.put((float, decimal), "DOUBLE") results.put((array1, array2), "ARRAY") - results.put((struct1, 
struct2), "STRUCT") + results.put((struct1, struct2), "OBJECT") results.put((dtInterval1, dtInterval2), "INTERVAL DAY TO SECOND") results.put((ymInterval1, ymInterval2), "INTERVAL YEAR TO MONTH") for (i1 <- inputs) { for (i2 <- inputs) { val expected = results.getOrElse((i1, i2), results.getOrElse((i2, i1), "VARIANT")) - val array = CreateArray(Seq(Cast(i1, VariantType), Cast(i2, VariantType))) + val elem1 = + if (i1.dataType.isInstanceOf[ArrayType] || i1.dataType.isInstanceOf[MapType] || + i1.dataType.isInstanceOf[StructType]) { + ToVariantObject(i1) + } else Cast(i1, VariantType) + val elem2 = + if (i2.dataType.isInstanceOf[ArrayType] || i2.dataType.isInstanceOf[MapType] || + i2.dataType.isInstanceOf[StructType]) { + ToVariantObject(i2) + } else Cast(i2, VariantType) + val array = CreateArray(Seq(elem1, elem2)) checkEvaluation(SchemaOfVariant(Cast(array, VariantType)).replacement, s"ARRAY<$expected>") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala index 854a3e8f7a74d..776600bbdcf5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala @@ -189,7 +189,7 @@ class ExtractPythonUDFFromJoinConditionSuite extends PlanTest { } checkError( exception = e, - errorClass = "UNSUPPORTED_FEATURE.PYTHON_UDF_IN_ON_CLAUSE", + condition = "UNSUPPORTED_FEATURE.PYTHON_UDF_IN_ON_CLAUSE", parameters = Map("joinType" -> joinType.sql) ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 5aeb27f7ee6b4..451236162343b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -27,12 +27,13 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Expand, Filter, LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, MetadataBuilder} -class PropagateEmptyRelationSuite extends PlanTest { +class PropagateEmptyRelationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("PropagateEmptyRelation", Once, + Batch("PropagateEmptyRelation", FixedPoint(1), CombineUnions, ReplaceDistinctWithAggregate, ReplaceExceptWithAntiJoin, @@ -45,7 +46,7 @@ class PropagateEmptyRelationSuite extends PlanTest { object OptimizeWithoutPropagateEmptyRelation extends RuleExecutor[LogicalPlan] { val batches = - Batch("OptimizeWithoutPropagateEmptyRelation", Once, + Batch("OptimizeWithoutPropagateEmptyRelation", FixedPoint(1), CombineUnions, ReplaceDistinctWithAggregate, ReplaceExceptWithAntiJoin, @@ -216,10 +217,24 @@ class PropagateEmptyRelationSuite extends PlanTest { .where($"a" =!= 200) .orderBy($"a".asc) - val optimized = Optimize.execute(query.analyze) - val correctAnswer = LocalRelation(output, isStreaming = true) + withSQLConf( + 
SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true") { + val optimized = Optimize.execute(query.analyze) + val correctAnswer = LocalRelation(output, isStreaming = true) + comparePlans(optimized, correctAnswer) + } - comparePlans(optimized, correctAnswer) + withSQLConf( + SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "false") { + val optimized = Optimize.execute(query.analyze) + val correctAnswer = relation + .where(false) + .where($"a" > 1) + .select($"a") + .where($"a" =!= 200) + .orderBy($"a".asc).analyze + comparePlans(optimized, correctAnswer) + } } test("SPARK-47305 correctly tag isStreaming when propagating empty relation " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index b81a57f4f8cd5..66ded338340f3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -174,4 +174,38 @@ class PruneFiltersSuite extends PlanTest { testRelation.where(!$"a".attr.in(1, 3, 5) && $"a".attr === 7 && $"b".attr === 1) .where(Rand(10) > 0.1 && Rand(10) < 1.1).analyze) } + + test("Streaming relation is not lost under true filter") { + Seq("true", "false").foreach(x => withSQLConf( + SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> x) { + val streamingRelation = + LocalRelation(Seq($"a".int, $"b".int, $"c".int), Nil, isStreaming = true) + val originalQuery = streamingRelation.where(10 > 5).select($"a").analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = streamingRelation.select($"a").analyze + comparePlans(optimized, correctAnswer) + }) + } + + test("Streaming relation is not lost under false filter") { + withSQLConf( + SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true") { + val streamingRelation = + LocalRelation(Seq($"a".int, $"b".int, $"c".int), Nil, isStreaming = true) + val originalQuery = streamingRelation.where(10 < 5).select($"a").analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = streamingRelation.select($"a").analyze + comparePlans(optimized, correctAnswer) + } + + withSQLConf( + SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "false") { + val streamingRelation = + LocalRelation(Seq($"a".int, $"b".int, $"c".int), Nil, isStreaming = true) + val originalQuery = streamingRelation.where(10 < 5).select($"a").analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = streamingRelation.where(10 < 5).select($"a").analyze + comparePlans(optimized, correctAnswer) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReassignLambdaVariableIDSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReassignLambdaVariableIDSuite.scala index bf9f922978f6d..677a5d7928fc1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReassignLambdaVariableIDSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReassignLambdaVariableIDSuite.scala @@ -58,7 +58,7 @@ class ReassignLambdaVariableIDSuite extends PlanTest { val query = testRelation.where(var1 && var2) checkError( exception = intercept[SparkException](Optimize.execute(query)), - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map( "message" -> "LambdaVariable IDs in a query should be all 
positive or negative.")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index a50842a26b2ce..eaa651e62455e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -59,7 +59,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { exception = intercept[AnalysisException] { testFilter(originalCond = Literal(null, IntegerType), expectedCond = FalseLiteral) }, - errorClass = "DATATYPE_MISMATCH.FILTER_NOT_BOOLEAN", + condition = "DATATYPE_MISMATCH.FILTER_NOT_BOOLEAN", parameters = Map("sqlExpr" -> "\"NULL\"", "filter" -> "\"NULL\"", "type" -> "\"INT\"") ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index ac136dfb898ef..4d31999ded655 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.{Literal, Round} import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} @@ -109,4 +109,20 @@ class RewriteDistinctAggregatesSuite extends PlanTest { case _ => fail(s"Plan is not rewritten:\n$rewrite") } } + + test("SPARK-49261: Literals in grouping expressions shouldn't result in unresolved aggregation") { + val relation = testRelation2 + .select(Literal(6).as("gb"), $"a", $"b", $"c", $"d") + val input = relation + .groupBy($"a", $"gb")( + countDistinct($"b").as("agg1"), + countDistinct($"d").as("agg2"), + Round(sum($"c").as("sum1"), 6)).analyze + val rewriteFold = FoldablePropagation(input) + // without the fix, the below produces an unresolved plan + val rewrite = RewriteDistinctAggregates(rewriteFold) + if (!rewrite.resolved) { + fail(s"Plan is not as expected:\n$rewrite") + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index d514f777e5544..b7e2490b552cc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -20,14 +20,16 @@ package org.apache.spark.sql.catalyst.parser import java.util.Locale import org.apache.spark.SparkThrowable +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.{EqualTo, Hex, Literal} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.connector.catalog.IdentityColumnSpec import 
org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.{after, first} import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, ClusterByTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{Decimal, IntegerType, LongType, StringType, StructType, TimestampType} +import org.apache.spark.sql.types.{DataType, Decimal, IntegerType, LongType, StringType, StructType, TimestampType} import org.apache.spark.storage.StorageLevelMapper import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -79,7 +81,7 @@ class DDLParserSuite extends AnalysisTest { val sql = "CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet" checkError( exception = parseException(sql), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "':'", "hint" -> "")) } @@ -380,7 +382,7 @@ class DDLParserSuite extends AnalysisTest { |Columns: p2 string""".stripMargin checkError( exception = parseException(createSql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> value1), context = ExpectedContext( fragment = createSql, @@ -390,7 +392,7 @@ class DDLParserSuite extends AnalysisTest { val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") checkError( exception = parseException(replaceSql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> value1), context = ExpectedContext( fragment = replaceSql, @@ -405,7 +407,7 @@ class DDLParserSuite extends AnalysisTest { |Columns: p2 string""".stripMargin checkError( exception = parseException(createSqlWithExpr), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> value2), context = ExpectedContext( fragment = createSqlWithExpr, @@ -415,7 +417,7 @@ class DDLParserSuite extends AnalysisTest { val replaceSqlWithExpr = createSqlWithExpr.replaceFirst("CREATE", "REPLACE") checkError( exception = parseException(replaceSqlWithExpr), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> value2), context = ExpectedContext( fragment = replaceSqlWithExpr, @@ -482,7 +484,7 @@ class DDLParserSuite extends AnalysisTest { "which also specifies a serde" checkError( exception = parseException(createSql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> value), context = ExpectedContext( fragment = createSql, @@ -492,7 +494,7 @@ class DDLParserSuite extends AnalysisTest { val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") checkError( exception = parseException(replaceSql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> value), context = ExpectedContext( fragment = replaceSql, @@ -538,7 +540,7 @@ class DDLParserSuite extends AnalysisTest { val value = "ROW FORMAT DELIMITED is only compatible with 'textfile', not 'otherformat'" checkError( exception = parseException(createFailSql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> value), context = ExpectedContext( fragment = createFailSql, @@ -548,7 +550,7 @@ class 
DDLParserSuite extends AnalysisTest { val replaceFailSql = createFailSql.replaceFirst("CREATE", "REPLACE") checkError( exception = parseException(replaceFailSql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> value), context = ExpectedContext( fragment = replaceFailSql, @@ -610,7 +612,7 @@ class DDLParserSuite extends AnalysisTest { |STORED AS parquet""".stripMargin checkError( exception = parseException(createSql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "CREATE TABLE ... USING ... STORED AS PARQUET "), context = ExpectedContext( fragment = createSql, @@ -620,7 +622,7 @@ class DDLParserSuite extends AnalysisTest { val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") checkError( exception = parseException(replaceSql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "REPLACE TABLE ... USING ... STORED AS PARQUET "), context = ExpectedContext( fragment = replaceSql, @@ -635,7 +637,7 @@ class DDLParserSuite extends AnalysisTest { |ROW FORMAT SERDE 'customSerde'""".stripMargin checkError( exception = parseException(createSql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "CREATE TABLE ... USING ... ROW FORMAT SERDE CUSTOMSERDE"), context = ExpectedContext( fragment = createSql, @@ -645,7 +647,7 @@ class DDLParserSuite extends AnalysisTest { val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") checkError( exception = parseException(replaceSql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "REPLACE TABLE ... USING ... ROW FORMAT SERDE CUSTOMSERDE"), context = ExpectedContext( fragment = replaceSql, @@ -660,7 +662,7 @@ class DDLParserSuite extends AnalysisTest { |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','""".stripMargin checkError( exception = parseException(createSql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "CREATE TABLE ... USING ... ROW FORMAT DELIMITED"), context = ExpectedContext( fragment = createSql, @@ -670,7 +672,7 @@ class DDLParserSuite extends AnalysisTest { val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") checkError( exception = parseException(replaceSql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "REPLACE TABLE ... USING ... 
ROW FORMAT DELIMITED"), context = ExpectedContext( fragment = replaceSql, @@ -685,7 +687,7 @@ class DDLParserSuite extends AnalysisTest { val fragment = "STORED BY 'handler'" checkError( exception = parseException(createSql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "STORED BY"), context = ExpectedContext( fragment = fragment, @@ -695,7 +697,7 @@ class DDLParserSuite extends AnalysisTest { val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") checkError( exception = parseException(replaceSql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "STORED BY"), context = ExpectedContext( fragment = fragment, @@ -707,7 +709,7 @@ class DDLParserSuite extends AnalysisTest { val sql1 = "CREATE TABLE my_tab (id bigint) SKEWED BY (id) ON (1,2,3)" checkError( exception = parseException(sql1), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "CREATE TABLE ... SKEWED BY"), context = ExpectedContext( fragment = sql1, @@ -717,7 +719,7 @@ class DDLParserSuite extends AnalysisTest { val sql2 = "REPLACE TABLE my_tab (id bigint) SKEWED BY (id) ON (1,2,3)" checkError( exception = parseException(sql2), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "CREATE TABLE ... SKEWED BY"), context = ExpectedContext( fragment = sql2, @@ -737,7 +739,7 @@ class DDLParserSuite extends AnalysisTest { val sql1 = createTableHeader("TBLPROPERTIES('test' = 'test2')") checkError( exception = parseException(sql1), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "TBLPROPERTIES"), context = ExpectedContext( fragment = sql1, @@ -747,7 +749,7 @@ class DDLParserSuite extends AnalysisTest { val sql2 = createTableHeader("LOCATION '/tmp/file'") checkError( exception = parseException(sql2), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "LOCATION"), context = ExpectedContext( fragment = sql2, @@ -757,7 +759,7 @@ class DDLParserSuite extends AnalysisTest { val sql3 = createTableHeader("COMMENT 'a table'") checkError( exception = parseException(sql3), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "COMMENT"), context = ExpectedContext( fragment = sql3, @@ -767,7 +769,7 @@ class DDLParserSuite extends AnalysisTest { val sql4 = createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS") checkError( exception = parseException(sql4), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "CLUSTERED BY"), context = ExpectedContext( fragment = sql4, @@ -777,7 +779,7 @@ class DDLParserSuite extends AnalysisTest { val sql5 = createTableHeader("PARTITIONED BY (b)") checkError( exception = parseException(sql5), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "PARTITIONED BY"), context = ExpectedContext( fragment = sql5, @@ -787,7 +789,7 @@ class DDLParserSuite extends AnalysisTest { val sql6 = createTableHeader("PARTITIONED BY (c int)") checkError( exception = parseException(sql6), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "PARTITIONED BY"), context = ExpectedContext( fragment = sql6, @@ -797,7 +799,7 @@ class 
DDLParserSuite extends AnalysisTest { val sql7 = createTableHeader("STORED AS parquet") checkError( exception = parseException(sql7), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "STORED AS/BY"), context = ExpectedContext( fragment = sql7, @@ -807,7 +809,7 @@ class DDLParserSuite extends AnalysisTest { val sql8 = createTableHeader("STORED AS INPUTFORMAT 'in' OUTPUTFORMAT 'out'") checkError( exception = parseException(sql8), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "STORED AS/BY"), context = ExpectedContext( fragment = sql8, @@ -817,7 +819,7 @@ class DDLParserSuite extends AnalysisTest { val sql9 = createTableHeader("ROW FORMAT SERDE 'serde'") checkError( exception = parseException(sql9), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "ROW FORMAT"), context = ExpectedContext( fragment = sql9, @@ -827,7 +829,7 @@ class DDLParserSuite extends AnalysisTest { val sql10 = replaceTableHeader("TBLPROPERTIES('test' = 'test2')") checkError( exception = parseException(sql10), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "TBLPROPERTIES"), context = ExpectedContext( fragment = sql10, @@ -837,7 +839,7 @@ class DDLParserSuite extends AnalysisTest { val sql11 = replaceTableHeader("LOCATION '/tmp/file'") checkError( exception = parseException(sql11), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "LOCATION"), context = ExpectedContext( fragment = sql11, @@ -847,7 +849,7 @@ class DDLParserSuite extends AnalysisTest { val sql12 = replaceTableHeader("COMMENT 'a table'") checkError( exception = parseException(sql12), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "COMMENT"), context = ExpectedContext( fragment = sql12, @@ -857,7 +859,7 @@ class DDLParserSuite extends AnalysisTest { val sql13 = replaceTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS") checkError( exception = parseException(sql13), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "CLUSTERED BY"), context = ExpectedContext( fragment = sql13, @@ -867,7 +869,7 @@ class DDLParserSuite extends AnalysisTest { val sql14 = replaceTableHeader("PARTITIONED BY (b)") checkError( exception = parseException(sql14), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "PARTITIONED BY"), context = ExpectedContext( fragment = sql14, @@ -877,7 +879,7 @@ class DDLParserSuite extends AnalysisTest { val sql15 = replaceTableHeader("PARTITIONED BY (c int)") checkError( exception = parseException(sql15), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "PARTITIONED BY"), context = ExpectedContext( fragment = sql15, @@ -887,7 +889,7 @@ class DDLParserSuite extends AnalysisTest { val sql16 = replaceTableHeader("STORED AS parquet") checkError( exception = parseException(sql16), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "STORED AS/BY"), context = ExpectedContext( fragment = sql16, @@ -897,7 +899,7 @@ class DDLParserSuite extends AnalysisTest { val sql17 = replaceTableHeader("STORED AS INPUTFORMAT 'in' OUTPUTFORMAT 'out'") checkError( exception = parseException(sql17), - errorClass = "DUPLICATE_CLAUSES", + 
condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "STORED AS/BY"), context = ExpectedContext( fragment = sql17, @@ -907,7 +909,7 @@ class DDLParserSuite extends AnalysisTest { val sql18 = replaceTableHeader("ROW FORMAT SERDE 'serde'") checkError( exception = parseException(sql18), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "ROW FORMAT"), context = ExpectedContext( fragment = sql18, @@ -917,7 +919,7 @@ class DDLParserSuite extends AnalysisTest { val sql19 = createTableHeader("CLUSTER BY (a)") checkError( exception = parseException(sql19), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "CLUSTER BY"), context = ExpectedContext( fragment = sql19, @@ -927,7 +929,7 @@ class DDLParserSuite extends AnalysisTest { val sql20 = replaceTableHeader("CLUSTER BY (a)") checkError( exception = parseException(sql20), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "CLUSTER BY"), context = ExpectedContext( fragment = sql20, @@ -1231,7 +1233,7 @@ class DDLParserSuite extends AnalysisTest { val fragment = "bad_type" checkError( exception = parseException(sql), - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map("typeName" -> "\"BAD_TYPE\""), context = ExpectedContext( fragment = fragment, @@ -1282,19 +1284,19 @@ class DDLParserSuite extends AnalysisTest { val sql1 = "ALTER TABLE table_name ALTER COLUMN a.b.c TYPE bigint COMMENT 'new comment'" checkError( exception = parseException(sql1), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'COMMENT'", "hint" -> "")) val sql2 = "ALTER TABLE table_name ALTER COLUMN a.b.c TYPE bigint COMMENT AFTER d" checkError( exception = parseException(sql2), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'COMMENT'", "hint" -> "")) val sql3 = "ALTER TABLE table_name ALTER COLUMN a.b.c TYPE bigint COMMENT 'new comment' AFTER d" checkError( exception = parseException(sql3), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'COMMENT'", "hint" -> "")) } @@ -1364,7 +1366,7 @@ class DDLParserSuite extends AnalysisTest { val sql4 = "ALTER TABLE table_name CHANGE COLUMN a.b.c new_name INT" checkError( exception = parseException(sql4), - errorClass = "_LEGACY_ERROR_TEMP_0034", + condition = "_LEGACY_ERROR_TEMP_0034", parameters = Map( "operation" -> "Renaming column", "command" -> "ALTER COLUMN", @@ -1378,7 +1380,7 @@ class DDLParserSuite extends AnalysisTest { val sql5 = "ALTER TABLE table_name PARTITION (a='1') CHANGE COLUMN a.b.c c INT" checkError( exception = parseException(sql5), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE ... PARTITION ... CHANGE COLUMN"), context = ExpectedContext( fragment = sql5, @@ -1425,7 +1427,7 @@ class DDLParserSuite extends AnalysisTest { val sql5 = "ALTER TABLE table_name PARTITION (a='1') REPLACE COLUMNS (x string)" checkError( exception = parseException(sql5), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE ... PARTITION ... 
REPLACE COLUMNS"), context = ExpectedContext( fragment = sql5, @@ -1435,7 +1437,7 @@ class DDLParserSuite extends AnalysisTest { val sql6 = "ALTER TABLE table_name REPLACE COLUMNS (x string NOT NULL)" checkError( exception = parseException(sql6), - errorClass = "_LEGACY_ERROR_TEMP_0034", + condition = "_LEGACY_ERROR_TEMP_0034", parameters = Map("operation" -> "NOT NULL", "command" -> "REPLACE COLUMNS", "msg" -> ""), context = ExpectedContext( fragment = sql6, @@ -1445,7 +1447,7 @@ class DDLParserSuite extends AnalysisTest { val sql7 = "ALTER TABLE table_name REPLACE COLUMNS (x string FIRST)" checkError( exception = parseException(sql7), - errorClass = "_LEGACY_ERROR_TEMP_0034", + condition = "_LEGACY_ERROR_TEMP_0034", parameters = Map( "operation" -> "Column position", "command" -> "REPLACE COLUMNS", @@ -1458,7 +1460,7 @@ class DDLParserSuite extends AnalysisTest { val sql8 = "ALTER TABLE table_name REPLACE COLUMNS (a.b.c string)" checkError( exception = parseException(sql8), - errorClass = "_LEGACY_ERROR_TEMP_0034", + condition = "_LEGACY_ERROR_TEMP_0034", parameters = Map( "operation" -> "Replacing with a nested column", "command" -> "REPLACE COLUMNS", @@ -1471,7 +1473,7 @@ class DDLParserSuite extends AnalysisTest { val sql9 = "ALTER TABLE table_name REPLACE COLUMNS (a STRING COMMENT 'x' COMMENT 'y')" checkError( exception = parseException(sql9), - errorClass = "ALTER_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", + condition = "ALTER_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", parameters = Map( "type" -> "REPLACE", "columnName" -> "a", @@ -1646,7 +1648,7 @@ class DDLParserSuite extends AnalysisTest { |PARTITION (p1 = 3, p2) IF NOT EXISTS""".stripMargin checkError( exception = parseException(sql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> "IF NOT EXISTS with dynamic partitions: p2"), context = ExpectedContext( fragment = fragment, @@ -1664,7 +1666,7 @@ class DDLParserSuite extends AnalysisTest { |PARTITION (p1 = 3) IF NOT EXISTS""".stripMargin checkError( exception = parseException(sql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "INSERT INTO ... 
IF NOT EXISTS"), context = ExpectedContext( fragment = fragment, @@ -1704,7 +1706,7 @@ class DDLParserSuite extends AnalysisTest { checkError( exception = parseException( "INSERT INTO TABLE t1 BY NAME (c1,c2) SELECT * FROM tmp_view"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map( "error" -> "'c1'", "hint" -> "") @@ -1729,7 +1731,7 @@ class DDLParserSuite extends AnalysisTest { val sql = "DELETE FROM testcat.ns1.ns2.tbl AS t(a,b,c,d) WHERE d = 2" checkError( exception = parseException(sql), - errorClass = "COLUMN_ALIASES_NOT_ALLOWED", + condition = "COLUMN_ALIASES_NOT_ALLOWED", parameters = Map("op" -> "DELETE"), context = ExpectedContext( fragment = sql, @@ -1771,7 +1773,7 @@ class DDLParserSuite extends AnalysisTest { |WHERE d=2""".stripMargin checkError( exception = parseException(sql), - errorClass = "COLUMN_ALIASES_NOT_ALLOWED", + condition = "COLUMN_ALIASES_NOT_ALLOWED", parameters = Map("op" -> "UPDATE"), context = ExpectedContext( fragment = sql, @@ -1931,7 +1933,7 @@ class DDLParserSuite extends AnalysisTest { """.stripMargin checkError( exception = parseException(sql), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'*'", "hint" -> "")) } @@ -1977,7 +1979,7 @@ class DDLParserSuite extends AnalysisTest { .stripMargin checkError( exception = parseException(sql), - errorClass = "COLUMN_ALIASES_NOT_ALLOWED", + condition = "COLUMN_ALIASES_NOT_ALLOWED", parameters = Map("op" -> "MERGE"), context = ExpectedContext( fragment = sql, @@ -2056,7 +2058,7 @@ class DDLParserSuite extends AnalysisTest { |THEN INSERT (target.col1, target.col2) values (source.col1, source.col2)""".stripMargin checkError( exception = parseException(sql), - errorClass = "NON_LAST_MATCHED_CLAUSE_OMIT_CONDITION", + condition = "NON_LAST_MATCHED_CLAUSE_OMIT_CONDITION", parameters = Map.empty, context = ExpectedContext( fragment = sql, @@ -2079,7 +2081,7 @@ class DDLParserSuite extends AnalysisTest { |THEN INSERT (target.col1, target.col2) values (source.col1, source.col2)""".stripMargin checkError( exception = parseException(sql), - errorClass = "NON_LAST_NOT_MATCHED_BY_TARGET_CLAUSE_OMIT_CONDITION", + condition = "NON_LAST_NOT_MATCHED_BY_TARGET_CLAUSE_OMIT_CONDITION", parameters = Map.empty, context = ExpectedContext( fragment = sql, @@ -2103,7 +2105,7 @@ class DDLParserSuite extends AnalysisTest { |WHEN NOT MATCHED BY SOURCE THEN DELETE""".stripMargin checkError( exception = parseException(sql), - errorClass = "NON_LAST_NOT_MATCHED_BY_SOURCE_CLAUSE_OMIT_CONDITION", + condition = "NON_LAST_NOT_MATCHED_BY_SOURCE_CLAUSE_OMIT_CONDITION", parameters = Map.empty, context = ExpectedContext( fragment = sql, @@ -2118,7 +2120,7 @@ class DDLParserSuite extends AnalysisTest { |ON target.col1 = source.col1""".stripMargin checkError( exception = parseException(sql), - errorClass = "_LEGACY_ERROR_TEMP_0008", + condition = "_LEGACY_ERROR_TEMP_0008", parameters = Map.empty, context = ExpectedContext( fragment = sql, @@ -2209,7 +2211,7 @@ class DDLParserSuite extends AnalysisTest { val sql1 = "analyze table a.b.c compute statistics xxxx" checkError( exception = parseException(sql1), - errorClass = "INVALID_SQL_SYNTAX.ANALYZE_TABLE_UNEXPECTED_NOSCAN", + condition = "INVALID_SQL_SYNTAX.ANALYZE_TABLE_UNEXPECTED_NOSCAN", parameters = Map("ctx" -> "XXXX"), context = ExpectedContext( fragment = sql1, @@ -2219,7 +2221,7 @@ class DDLParserSuite extends AnalysisTest { val sql2 = "analyze table a.b.c partition (a) compute statistics xxxx" 
checkError( exception = parseException(sql2), - errorClass = "INVALID_SQL_SYNTAX.ANALYZE_TABLE_UNEXPECTED_NOSCAN", + condition = "INVALID_SQL_SYNTAX.ANALYZE_TABLE_UNEXPECTED_NOSCAN", parameters = Map("ctx" -> "XXXX"), context = ExpectedContext( fragment = sql2, @@ -2238,7 +2240,7 @@ class DDLParserSuite extends AnalysisTest { val sql = "ANALYZE TABLES IN a.b.c COMPUTE STATISTICS xxxx" checkError( exception = parseException(sql), - errorClass = "INVALID_SQL_SYNTAX.ANALYZE_TABLE_UNEXPECTED_NOSCAN", + condition = "INVALID_SQL_SYNTAX.ANALYZE_TABLE_UNEXPECTED_NOSCAN", parameters = Map("ctx" -> "XXXX"), context = ExpectedContext( fragment = sql, @@ -2250,7 +2252,7 @@ class DDLParserSuite extends AnalysisTest { val sql1 = "ANALYZE TABLE a.b.c COMPUTE STATISTICS FOR COLUMNS" checkError( exception = parseException(sql1), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "end of input", "hint" -> "")) comparePlans( @@ -2287,13 +2289,13 @@ class DDLParserSuite extends AnalysisTest { val sql2 = "ANALYZE TABLE a.b.c COMPUTE STATISTICS FOR ALL COLUMNS key, value" checkError( exception = parseException(sql2), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'key'", "hint" -> "")) // expecting {<EOF>, ';'} val sql3 = "ANALYZE TABLE a.b.c COMPUTE STATISTICS FOR ALL" checkError( exception = parseException(sql3), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "end of input", "hint" -> ": missing 'COLUMNS'")) } @@ -2370,7 +2372,7 @@ class DDLParserSuite extends AnalysisTest { val sql = "CACHE TABLE a.b.c AS SELECT * FROM testData" checkError( exception = parseException(sql), - errorClass = "_LEGACY_ERROR_TEMP_0037", + condition = "_LEGACY_ERROR_TEMP_0037", parameters = Map("quoted" -> "a.b"), context = ExpectedContext( fragment = sql, @@ -2382,7 +2384,7 @@ class DDLParserSuite extends AnalysisTest { val createTableSql = "create table test_table using my_data_source options (password)" checkError( exception = parseException(createTableSql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> "A value must be specified for the key: password."), context = ExpectedContext( fragment = createTableSql, @@ -2413,7 +2415,7 @@ class DDLParserSuite extends AnalysisTest { |(dt='2009-09-09', country='uk')""".stripMargin checkError( exception = parseException(sql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER VIEW ... ADD PARTITION"), context = ExpectedContext( fragment = sql, @@ -2698,14 +2700,14 @@ class DDLParserSuite extends AnalysisTest { val sql1 = "ALTER TABLE t1 ALTER COLUMN a.b.c SET DEFAULT " checkError( exception = parseException(sql1), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "end of input", "hint" -> "")) // It is not possible to both SET DEFAULT and DROP DEFAULT at the same time. // This results in a parsing error.
val sql2 = "ALTER TABLE t1 ALTER COLUMN a.b.c DROP DEFAULT SET DEFAULT 42" checkError( exception = parseException(sql2), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'SET'", "hint" -> "")) comparePlans( @@ -2724,7 +2726,7 @@ class DDLParserSuite extends AnalysisTest { val fragment = "b STRING NOT NULL DEFAULT \"abc\"" checkError( exception = parseException(sql), - errorClass = "UNSUPPORTED_DEFAULT_VALUE.WITH_SUGGESTION", + condition = "UNSUPPORTED_DEFAULT_VALUE.WITH_SUGGESTION", parameters = Map.empty, context = ExpectedContext( fragment = fragment, @@ -2784,7 +2786,7 @@ class DDLParserSuite extends AnalysisTest { exception = intercept[ParseException]( parsePlan( "CREATE TABLE my_tab(a INT, b STRING NOT NULL DEFAULT \"abc\" NOT NULL)")), - errorClass = "CREATE_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", + condition = "CREATE_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", parameters = Map( "columnName" -> "b", "optionName" -> "NOT NULL"), @@ -2794,7 +2796,7 @@ class DDLParserSuite extends AnalysisTest { exception = intercept[ParseException]( parsePlan( "CREATE TABLE my_tab(a INT, b STRING DEFAULT \"123\" NOT NULL DEFAULT \"abc\")")), - errorClass = "CREATE_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", + condition = "CREATE_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", parameters = Map( "columnName" -> "b", "optionName" -> "DEFAULT"), @@ -2804,7 +2806,7 @@ class DDLParserSuite extends AnalysisTest { exception = intercept[ParseException]( parsePlan( "CREATE TABLE my_tab(a INT, b STRING COMMENT \"abc\" NOT NULL COMMENT \"abc\")")), - errorClass = "CREATE_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", + condition = "CREATE_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", parameters = Map( "columnName" -> "b", "optionName" -> "COMMENT"), @@ -2836,7 +2838,7 @@ class DDLParserSuite extends AnalysisTest { checkError( exception = parseException("CREATE TABLE my_tab(a INT, " + "b INT GENERATED ALWAYS AS (a + 1) GENERATED ALWAYS AS (a + 2)) USING PARQUET"), - errorClass = "CREATE_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", + condition = "CREATE_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", parameters = Map("columnName" -> "b", "optionName" -> "GENERATED ALWAYS AS"), context = ExpectedContext( fragment = "b INT GENERATED ALWAYS AS (a + 1) GENERATED ALWAYS AS (a + 2)", @@ -2848,18 +2850,225 @@ class DDLParserSuite extends AnalysisTest { checkError( exception = parseException( "CREATE TABLE my_tab(a INT, b INT GENERATED ALWAYS AS ()) USING PARQUET"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "')'", "hint" -> "") ) // No parenthesis checkError( exception = parseException( "CREATE TABLE my_tab(a INT, b INT GENERATED ALWAYS AS a + 1) USING PARQUET"), - errorClass = "PARSE_SYNTAX_ERROR", - parameters = Map("error" -> "'a'", "hint" -> ": missing '('") + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'a'", "hint" -> "") ) } + test("SPARK-48824: implement parser support for " + + "GENERATED ALWAYS/BY DEFAULT AS IDENTITY columns in tables ") { + def parseAndCompareIdentityColumnPlan( + identityColumnDataTypeStr: String, + identityColumnDefStr: String, + identityColumnSpecStr: String, + expectedDataType: DataType, + expectedStart: Long, + expectedStep: Long, + expectedAllowExplicitInsert: Boolean): Unit = { + val columnsWithIdentitySpec = Seq( + ColumnDefinition( + name = "id", + dataType = expectedDataType, + nullable = true, + identityColumnSpec = Some( + new IdentityColumnSpec( + expectedStart, + expectedStep, + expectedAllowExplicitInsert + ) + ) + ), + 
ColumnDefinition("val", IntegerType) + ) + comparePlans( + parsePlan( + s"CREATE TABLE my_tab(id $identityColumnDataTypeStr GENERATED $identityColumnDefStr" + + s" AS IDENTITY $identityColumnSpecStr, val INT) USING parquet" + ), + CreateTable( + UnresolvedIdentifier(Seq("my_tab")), + columnsWithIdentitySpec, + Seq.empty[Transform], + UnresolvedTableSpec( + Map.empty[String, String], + Some("parquet"), + OptionList(Seq.empty), + None, + None, + None, + false + ), + false + ) + ) + + comparePlans( + parsePlan( + s"REPLACE TABLE my_tab(id $identityColumnDataTypeStr GENERATED $identityColumnDefStr" + + s" AS IDENTITY $identityColumnSpecStr, val INT) USING parquet" + ), + ReplaceTable( + UnresolvedIdentifier(Seq("my_tab")), + columnsWithIdentitySpec, + Seq.empty[Transform], + UnresolvedTableSpec( + Map.empty[String, String], + Some("parquet"), + OptionList(Seq.empty), + None, + None, + None, + false + ), + false + ) + ) + } + for { + identityColumnDefStr <- Seq("BY DEFAULT", "ALWAYS") + identityColumnDataTypeStr <- Seq("BIGINT", "INT") + } { + val expectedAllowExplicitInsert = identityColumnDefStr == "BY DEFAULT" + val expectedDataType = identityColumnDataTypeStr match { + case "BIGINT" => LongType + case "INT" => IntegerType + } + parseAndCompareIdentityColumnPlan( + identityColumnDataTypeStr, + identityColumnDefStr, + "(START WITH 2 INCREMENT BY 2)", + expectedDataType, + expectedStart = 2, + expectedStep = 2, + expectedAllowExplicitInsert = expectedAllowExplicitInsert) + parseAndCompareIdentityColumnPlan( + identityColumnDataTypeStr, + identityColumnDefStr, + "(START WITH -2 INCREMENT BY -2)", + expectedDataType, + expectedStart = -2, + expectedStep = -2, + expectedAllowExplicitInsert = expectedAllowExplicitInsert) + parseAndCompareIdentityColumnPlan( + identityColumnDataTypeStr, + identityColumnDefStr, + "(START WITH 2)", + expectedDataType, + expectedStart = 2, + expectedStep = 1, + expectedAllowExplicitInsert = expectedAllowExplicitInsert) + parseAndCompareIdentityColumnPlan( + identityColumnDataTypeStr, + identityColumnDefStr, + "(START WITH -2)", + expectedDataType, + expectedStart = -2, + expectedStep = 1, + expectedAllowExplicitInsert = expectedAllowExplicitInsert) + parseAndCompareIdentityColumnPlan( + identityColumnDataTypeStr, + identityColumnDefStr, + "(INCREMENT BY 2)", + expectedDataType, + expectedStart = 1, + expectedStep = 2, + expectedAllowExplicitInsert = expectedAllowExplicitInsert) + parseAndCompareIdentityColumnPlan( + identityColumnDataTypeStr, + identityColumnDefStr, + "(INCREMENT BY -2)", + expectedDataType, + expectedStart = 1, + expectedStep = -2, + expectedAllowExplicitInsert = expectedAllowExplicitInsert) + parseAndCompareIdentityColumnPlan( + identityColumnDataTypeStr, + identityColumnDefStr, + "()", + expectedDataType, + expectedStart = 1, + expectedStep = 1, + expectedAllowExplicitInsert = expectedAllowExplicitInsert) + parseAndCompareIdentityColumnPlan( + identityColumnDataTypeStr, + identityColumnDefStr, + "", + expectedDataType, + expectedStart = 1, + expectedStep = 1, + expectedAllowExplicitInsert = expectedAllowExplicitInsert) + } + } + + test("SPARK-48824: Column cannot have both a generation expression and an identity column spec") { + checkError( + exception = intercept[AnalysisException] { + parsePlan(s"CREATE TABLE testcat.my_tab(id BIGINT GENERATED ALWAYS AS 1" + + s" GENERATED ALWAYS AS IDENTITY, val INT) USING foo") + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'1'", "hint" -> "") + ) + } + + test("SPARK-48824: Identity 
column step must not be zero") { + checkError( + exception = intercept[ParseException] { + parsePlan( + s"CREATE TABLE testcat.my_tab" + + s"(id BIGINT GENERATED ALWAYS AS IDENTITY(INCREMENT BY 0), val INT) USING foo" + ) + }, + condition = "IDENTITY_COLUMNS_ILLEGAL_STEP", + parameters = Map.empty, + context = ExpectedContext( + fragment = "id BIGINT GENERATED ALWAYS AS IDENTITY(INCREMENT BY 0)", + start = 28, + stop = 81) + ) + } + + test("SPARK-48824: Identity column datatype must be long or integer") { + checkError( + exception = intercept[ParseException] { + parsePlan( + s"CREATE TABLE testcat.my_tab(id FLOAT GENERATED ALWAYS AS IDENTITY(), val INT) USING foo" + ) + }, + condition = "IDENTITY_COLUMNS_UNSUPPORTED_DATA_TYPE", + parameters = Map("dataType" -> "FloatType"), + context = + ExpectedContext(fragment = "id FLOAT GENERATED ALWAYS AS IDENTITY()", start = 28, stop = 66) + ) + } + + test("SPARK-48824: Identity column sequence generator option cannot be duplicated") { + val identityColumnSpecStrs = Seq( + "(START WITH 0 START WITH 1)", + "(INCREMENT BY 1 INCREMENT BY 2)", + "(START WITH 0 INCREMENT BY 1 START WITH 1)", + "(INCREMENT BY 1 START WITH 0 INCREMENT BY 2)" + ) + for { + identitySpecStr <- identityColumnSpecStrs + } { + val exception = intercept[ParseException] { + parsePlan( + s"CREATE TABLE testcat.my_tab" + + s"(id BIGINT GENERATED ALWAYS AS IDENTITY $identitySpecStr, val INT) USING foo" + ) + } + assert(exception.getErrorClass === "IDENTITY_COLUMNS_DUPLICATED_SEQUENCE_GENERATOR_OPTION") + } + } + test("SPARK-42681: Relax ordering constraint for ALTER TABLE ADD COLUMN options") { // Positive test cases to verify that column definition options could be applied in any order. val expectedPlan = AddColumns( @@ -2887,7 +3096,7 @@ class DDLParserSuite extends AnalysisTest { exception = intercept[ParseException]( parsePlan("ALTER TABLE my_tab ADD COLUMN b STRING NOT NULL DEFAULT \"abc\" NOT NULL") ), - errorClass = "ALTER_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", + condition = "ALTER_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", parameters = Map("type" -> "ADD", "columnName" -> "b", "optionName" -> "NOT NULL"), context = ExpectedContext( fragment = "b STRING NOT NULL DEFAULT \"abc\" NOT NULL", @@ -2899,7 +3108,7 @@ class DDLParserSuite extends AnalysisTest { exception = intercept[ParseException]( parsePlan("ALTER TABLE my_tab ADD COLUMN b STRING DEFAULT \"123\" NOT NULL DEFAULT \"abc\"") ), - errorClass = "ALTER_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", + condition = "ALTER_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", parameters = Map("type" -> "ADD", "columnName" -> "b", "optionName" -> "DEFAULT"), context = ExpectedContext( fragment = "b STRING DEFAULT \"123\" NOT NULL DEFAULT \"abc\"", @@ -2911,7 +3120,7 @@ class DDLParserSuite extends AnalysisTest { exception = intercept[ParseException]( parsePlan("ALTER TABLE my_tab ADD COLUMN b STRING COMMENT \"abc\" NOT NULL COMMENT \"abc\"") ), - errorClass = "ALTER_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", + condition = "ALTER_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", parameters = Map("type" -> "ADD", "columnName" -> "b", "optionName" -> "COMMENT"), context = ExpectedContext( fragment = "b STRING COMMENT \"abc\" NOT NULL COMMENT \"abc\"", @@ -2923,7 +3132,7 @@ class DDLParserSuite extends AnalysisTest { exception = intercept[ParseException]( parsePlan("ALTER TABLE my_tab ADD COLUMN b STRING FIRST COMMENT \"abc\" AFTER y") ), - errorClass = "ALTER_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", + condition = "ALTER_TABLE_COLUMN_DESCRIPTOR_DUPLICATE", parameters = Map("type" -> "ADD", 
"columnName" -> "b", "optionName" -> "FIRST|AFTER"), context = ExpectedContext(fragment = "b STRING FIRST COMMENT \"abc\" AFTER y", start = 30, stop = 65) @@ -2935,7 +3144,7 @@ class DDLParserSuite extends AnalysisTest { "USING parquet CLUSTERED BY (a) INTO 2 BUCKETS CLUSTER BY (a)" checkError( exception = parseException(sql1), - errorClass = "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED", + condition = "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED", parameters = Map.empty, context = ExpectedContext(fragment = sql1, start = 0, stop = 96) ) @@ -2946,7 +3155,7 @@ class DDLParserSuite extends AnalysisTest { "USING parquet CLUSTERED BY (a) INTO 2 BUCKETS CLUSTER BY (a)" checkError( exception = parseException(sql1), - errorClass = "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED", + condition = "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED", parameters = Map.empty, context = ExpectedContext(fragment = sql1, start = 0, stop = 97) ) @@ -2957,7 +3166,7 @@ class DDLParserSuite extends AnalysisTest { "USING parquet CLUSTER BY (a) PARTITIONED BY (a)" checkError( exception = parseException(sql1), - errorClass = "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED", + condition = "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED", parameters = Map.empty, context = ExpectedContext(fragment = sql1, start = 0, stop = 83) ) @@ -2968,7 +3177,7 @@ class DDLParserSuite extends AnalysisTest { "USING parquet CLUSTER BY (a) PARTITIONED BY (a)" checkError( exception = parseException(sql1), - errorClass = "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED", + condition = "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED", parameters = Map.empty, context = ExpectedContext(fragment = sql1, start = 0, stop = 84) ) @@ -2986,7 +3195,7 @@ class DDLParserSuite extends AnalysisTest { checkError( exception = internalException(insertDirSql), - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "INSERT OVERWRITE DIRECTORY is not supported.")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index f11e920e4c07d..c416d21cfd4b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -138,12 +138,12 @@ class DataTypeParserSuite extends SparkFunSuite with SQLHelper { test("Do not print empty parentheses for no params") { checkError( exception = intercept("unknown"), - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map("typeName" -> "\"UNKNOWN\"") ) checkError( exception = intercept("unknown(1,2,3)"), - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map("typeName" -> "\"UNKNOWN(1,2,3)\"") ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala index e4f9b54680dc7..603d5d779769d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -34,7 +34,7 @@ class ErrorParserSuite extends AnalysisTest { test("semantic errors") { checkError( exception = parseException("select *\nfrom r\norder by q\ncluster by q"), - 
errorClass = "UNSUPPORTED_FEATURE.COMBINATION_QUERY_RESULT_CLAUSES", + condition = "UNSUPPORTED_FEATURE.COMBINATION_QUERY_RESULT_CLAUSES", parameters = Map.empty, context = ExpectedContext(fragment = "order by q\ncluster by q", start = 16, stop = 38)) } @@ -43,42 +43,42 @@ class ErrorParserSuite extends AnalysisTest { // scalastyle:off checkError( exception = parseException("USE \u0196pfel"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "\u0196pfel")) checkError( exception = parseException("USE \u88681"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "\u88681")) // scalastyle:on checkError( exception = parseException("USE https://www.spa.rk/bucket/pa-th.json?=&#%"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "https://www.spa.rk/bucket/pa-th.json?=&#%")) } test("hyphen in identifier - DDL tests") { checkError( exception = parseException("USE test-test"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-test")) checkError( exception = parseException("SET CATALOG test-test"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-test")) checkError( exception = parseException("CREATE DATABASE IF NOT EXISTS my-database"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "my-database")) checkError( exception = parseException( """ |ALTER DATABASE my-database |SET DBPROPERTIES ('p1'='v1')""".stripMargin), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "my-database")) checkError( exception = parseException("DROP DATABASE my-database"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "my-database")) checkError( exception = parseException( @@ -87,7 +87,7 @@ class ErrorParserSuite extends AnalysisTest { |CHANGE COLUMN |test-col TYPE BIGINT """.stripMargin), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-col")) checkError( exception = parseException( @@ -96,23 +96,23 @@ class ErrorParserSuite extends AnalysisTest { |DROP COLUMN |test-col, test """.stripMargin), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-col")) checkError( exception = parseException("CREATE TABLE test (attri-bute INT)"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "attri-bute")) checkError( exception = parseException("CREATE FUNCTION test-func as org.test.func"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-func")) checkError( exception = parseException("DROP FUNCTION test-func as org.test.func"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-func")) checkError( exception = parseException("SHOW FUNCTIONS LIKE test-func"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-func")) checkError( exception = parseException( @@ -123,7 +123,7 @@ class ErrorParserSuite extends AnalysisTest { |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') |AS SELECT * FROM src""".stripMargin), - errorClass = "INVALID_IDENTIFIER", + condition = 
"INVALID_IDENTIFIER", parameters = Map("ident" -> "page-view")) checkError( exception = parseException( @@ -131,31 +131,31 @@ class ErrorParserSuite extends AnalysisTest { |CREATE TABLE IF NOT EXISTS tab |USING test-provider |AS SELECT * FROM src""".stripMargin), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-provider")) checkError( exception = parseException("SHOW TABLES IN hyphen-database"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "hyphen-database")) checkError( exception = parseException("SHOW TABLE EXTENDED IN hyphen-db LIKE \"str\""), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "hyphen-db")) checkError( exception = parseException("DESC SCHEMA EXTENDED test-db"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-db")) checkError( exception = parseException("ANALYZE TABLE test-table PARTITION (part1)"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-table")) checkError( exception = parseException("CREATE TABLE t(c1 struct)"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-test")) checkError( exception = parseException("LOAD DATA INPATH \"path\" INTO TABLE my-tab"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "my-tab")) } @@ -163,28 +163,28 @@ class ErrorParserSuite extends AnalysisTest { // dml tests checkError( exception = parseException("SELECT * FROM table-with-hyphen"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "table-with-hyphen")) // special test case: minus in expression shouldn't be treated as hyphen in identifiers checkError( exception = parseException("SELECT a-b FROM table-with-hyphen"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "table-with-hyphen")) checkError( exception = parseException("SELECT a-b AS a-b FROM t"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "a-b")) checkError( exception = parseException("SELECT a-b FROM table-hyphen WHERE a-b = 0"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "table-hyphen")) checkError( exception = parseException("SELECT (a - test_func(b-c)) FROM test-table"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-table")) checkError( exception = parseException("WITH a-b AS (SELECT 1 FROM s) SELECT * FROM s;"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "a-b")) checkError( exception = parseException( @@ -193,7 +193,7 @@ class ErrorParserSuite extends AnalysisTest { |FROM t1 JOIN t2 |USING (a, b, at-tr) """.stripMargin), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "at-tr")) checkError( exception = parseException( @@ -202,7 +202,7 @@ class ErrorParserSuite extends AnalysisTest { |OVER (PARTITION BY category ORDER BY revenue DESC) as hyphen-rank |FROM productRevenue """.stripMargin), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "hyphen-rank")) checkError( exception = parseException( @@ 
-213,7 +213,7 @@ class ErrorParserSuite extends AnalysisTest { |GROUP BY fake-breaker |ORDER BY c """.stripMargin), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "grammar-breaker")) assert(parsePlan( """ @@ -234,7 +234,7 @@ class ErrorParserSuite extends AnalysisTest { |WINDOW hyphen-window AS | (PARTITION BY a, b ORDER BY c rows BETWEEN 1 PRECEDING AND 1 FOLLOWING) """.stripMargin), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "hyphen-window")) checkError( exception = parseException( @@ -242,7 +242,7 @@ class ErrorParserSuite extends AnalysisTest { |SELECT * FROM tab |WINDOW window_ref AS window-ref """.stripMargin), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "window-ref")) checkError( exception = parseException( @@ -251,7 +251,7 @@ class ErrorParserSuite extends AnalysisTest { |FROM t-a INNER JOIN tb |ON ta.a = tb.a AND ta.tag = tb.tag """.stripMargin), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "t-a")) checkError( exception = parseException( @@ -260,7 +260,7 @@ class ErrorParserSuite extends AnalysisTest { |SELECT a |SELECT b """.stripMargin), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-table")) checkError( exception = parseException( @@ -273,7 +273,7 @@ class ErrorParserSuite extends AnalysisTest { | FOR test-test IN ('dotNET', 'Java') |); """.stripMargin), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-test")) } @@ -281,23 +281,23 @@ class ErrorParserSuite extends AnalysisTest { // general bad types checkError( exception = parseException("SELECT cast(1 as badtype)"), - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map("typeName" -> "\"BADTYPE\""), context = ExpectedContext(fragment = "badtype", start = 17, stop = 23)) // special handling on char and varchar checkError( exception = parseException("SELECT cast('a' as CHAR)"), - errorClass = "DATATYPE_MISSING_SIZE", + condition = "DATATYPE_MISSING_SIZE", parameters = Map("type" -> "\"CHAR\""), context = ExpectedContext(fragment = "CHAR", start = 19, stop = 22)) checkError( exception = parseException("SELECT cast('a' as Varchar)"), - errorClass = "DATATYPE_MISSING_SIZE", + condition = "DATATYPE_MISSING_SIZE", parameters = Map("type" -> "\"VARCHAR\""), context = ExpectedContext(fragment = "Varchar", start = 19, stop = 25)) checkError( exception = parseException("SELECT cast('a' as Character)"), - errorClass = "DATATYPE_MISSING_SIZE", + condition = "DATATYPE_MISSING_SIZE", parameters = Map("type" -> "\"CHARACTER\""), context = ExpectedContext(fragment = "Character", start = 19, stop = 27)) } @@ -305,32 +305,32 @@ class ErrorParserSuite extends AnalysisTest { test("'!' where only NOT should be allowed") { checkError( exception = parseException("SELECT 1 ! IN (2)"), - errorClass = "SYNTAX_DISCONTINUED.BANG_EQUALS_NOT", + condition = "SYNTAX_DISCONTINUED.BANG_EQUALS_NOT", parameters = Map("clause" -> "!"), context = ExpectedContext(fragment = "!", start = 9, stop = 9)) checkError( exception = parseException("SELECT 'a' ! 
LIKE 'b'"), - errorClass = "SYNTAX_DISCONTINUED.BANG_EQUALS_NOT", + condition = "SYNTAX_DISCONTINUED.BANG_EQUALS_NOT", parameters = Map("clause" -> "!"), context = ExpectedContext(fragment = "!", start = 11, stop = 11)) checkError( exception = parseException("SELECT 1 ! BETWEEN 1 AND 2"), - errorClass = "SYNTAX_DISCONTINUED.BANG_EQUALS_NOT", + condition = "SYNTAX_DISCONTINUED.BANG_EQUALS_NOT", parameters = Map("clause" -> "!"), context = ExpectedContext(fragment = "!", start = 9, stop = 9)) checkError( exception = parseException("SELECT 1 IS ! NULL"), - errorClass = "SYNTAX_DISCONTINUED.BANG_EQUALS_NOT", + condition = "SYNTAX_DISCONTINUED.BANG_EQUALS_NOT", parameters = Map("clause" -> "!"), context = ExpectedContext(fragment = "!", start = 12, stop = 12)) checkError( exception = parseException("CREATE TABLE IF ! EXISTS t(c1 INT)"), - errorClass = "SYNTAX_DISCONTINUED.BANG_EQUALS_NOT", + condition = "SYNTAX_DISCONTINUED.BANG_EQUALS_NOT", parameters = Map("clause" -> "!"), context = ExpectedContext(fragment = "!", start = 16, stop = 16)) checkError( exception = parseException("CREATE TABLE t(c1 INT ! NULL)"), - errorClass = "SYNTAX_DISCONTINUED.BANG_EQUALS_NOT", + condition = "SYNTAX_DISCONTINUED.BANG_EQUALS_NOT", parameters = Map("clause" -> "!"), context = ExpectedContext(fragment = "!", start = 22, stop = 22)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 2654757177ee7..6d307d1cd9a87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -201,7 +201,7 @@ class ExpressionParserSuite extends AnalysisTest { checkError( exception = parseException("a like 'pattern%' escape '##'"), - errorClass = "INVALID_ESC", + condition = "INVALID_ESC", parameters = Map("invalidEscape" -> "'##'"), context = ExpectedContext( fragment = "like 'pattern%' escape '##'", @@ -210,7 +210,7 @@ class ExpressionParserSuite extends AnalysisTest { checkError( exception = parseException("a like 'pattern%' escape ''"), - errorClass = "INVALID_ESC", + condition = "INVALID_ESC", parameters = Map("invalidEscape" -> "''"), context = ExpectedContext( fragment = "like 'pattern%' escape ''", @@ -222,7 +222,7 @@ class ExpressionParserSuite extends AnalysisTest { checkError( exception = parseException("a not like 'pattern%' escape '\"/'"), - errorClass = "INVALID_ESC", + condition = "INVALID_ESC", parameters = Map("invalidEscape" -> "'\"/'"), context = ExpectedContext( fragment = "not like 'pattern%' escape '\"/'", @@ -231,7 +231,7 @@ class ExpressionParserSuite extends AnalysisTest { checkError( exception = parseException("a not like 'pattern%' escape ''"), - errorClass = "INVALID_ESC", + condition = "INVALID_ESC", parameters = Map("invalidEscape" -> "''"), context = ExpectedContext( fragment = "not like 'pattern%' escape ''", @@ -261,7 +261,7 @@ class ExpressionParserSuite extends AnalysisTest { Seq("any", "some", "all").foreach { quantifier => checkError( exception = parseException(s"a like $quantifier()"), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> "Expected something between '(' and ')'."), context = ExpectedContext( fragment = s"like $quantifier()", @@ -328,7 +328,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("`select`(all a, 
b)", $"select".function($"a", $"b")) checkError( exception = parseException("foo(a x)"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'x'", "hint" -> ": extra input 'x'")) } @@ -461,7 +461,7 @@ class ExpressionParserSuite extends AnalysisTest { // We cannot use an arbitrary expression. checkError( exception = parseException("foo(*) over (partition by a order by b rows exp(b) preceding)"), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> "Frame bound value must be a literal."), context = ExpectedContext( fragment = "exp(b) preceding", @@ -540,7 +540,7 @@ class ExpressionParserSuite extends AnalysisTest { Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) checkError( exception = parseException("timestamP_LTZ '2016-33-11 20:54:00.000'"), - errorClass = "INVALID_TYPED_LITERAL", + condition = "INVALID_TYPED_LITERAL", sqlState = "42604", parameters = Map( "valueType" -> "\"TIMESTAMP_LTZ\"", @@ -556,7 +556,7 @@ class ExpressionParserSuite extends AnalysisTest { Literal(LocalDateTime.parse("2016-03-11T20:54:00.000"))) checkError( exception = parseException("tImEstAmp_Ntz '2016-33-11 20:54:00.000'"), - errorClass = "INVALID_TYPED_LITERAL", + condition = "INVALID_TYPED_LITERAL", sqlState = "42604", parameters = Map( "valueType" -> "\"TIMESTAMP_NTZ\"", @@ -572,7 +572,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11"))) checkError( exception = parseException("DAtE 'mar 11 2016'"), - errorClass = "INVALID_TYPED_LITERAL", + condition = "INVALID_TYPED_LITERAL", sqlState = "42604", parameters = Map("valueType" -> "\"DATE\"", "value" -> "'mar 11 2016'"), context = ExpectedContext( @@ -585,7 +585,7 @@ class ExpressionParserSuite extends AnalysisTest { Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) checkError( exception = parseException("timestamP '2016-33-11 20:54:00.000'"), - errorClass = "INVALID_TYPED_LITERAL", + condition = "INVALID_TYPED_LITERAL", sqlState = "42604", parameters = Map("valueType" -> "\"TIMESTAMP\"", "value" -> "'2016-33-11 20:54:00.000'"), context = ExpectedContext( @@ -600,7 +600,7 @@ class ExpressionParserSuite extends AnalysisTest { checkError( exception = parseException("timestamP '2016-33-11 20:54:00.000'"), - errorClass = "INVALID_TYPED_LITERAL", + condition = "INVALID_TYPED_LITERAL", sqlState = "42604", parameters = Map("valueType" -> "\"TIMESTAMP\"", "value" -> "'2016-33-11 20:54:00.000'"), context = ExpectedContext( @@ -621,7 +621,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("INTERVAL '1 year 2 month'", ymIntervalLiteral) checkError( exception = parseException("Interval 'interval 1 yearsss 2 monthsss'"), - errorClass = "INVALID_TYPED_LITERAL", + condition = "INVALID_TYPED_LITERAL", parameters = Map( "valueType" -> "\"INTERVAL\"", "value" -> "'interval 1 yearsss 2 monthsss'" @@ -638,7 +638,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("INTERVAL '1 day 2 hour 3 minute 4.005006 second'", dtIntervalLiteral) checkError( exception = parseException("Interval 'interval 1 daysss 2 hoursss'"), - errorClass = "INVALID_TYPED_LITERAL", + condition = "INVALID_TYPED_LITERAL", parameters = Map( "valueType" -> "\"INTERVAL\"", "value" -> "'interval 1 daysss 2 hoursss'" @@ -651,7 +651,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("-interval '1 day 2 hour 3 minute 4.005006 second'", UnaryMinus(dtIntervalLiteral)) checkError( 
exception = parseException("INTERVAL '1 year 2 second'"), - errorClass = "_LEGACY_ERROR_TEMP_0029", + condition = "_LEGACY_ERROR_TEMP_0029", parameters = Map("literal" -> "INTERVAL '1 year 2 second'"), context = ExpectedContext( fragment = "INTERVAL '1 year 2 second'", @@ -664,7 +664,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("INTERVAL '3 month 1 hour'", intervalLiteral) checkError( exception = parseException("Interval 'interval 3 monthsss 1 hoursss'"), - errorClass = "INVALID_TYPED_LITERAL", + condition = "INVALID_TYPED_LITERAL", parameters = Map( "valueType" -> "\"INTERVAL\"", "value" -> "'interval 3 monthsss 1 hoursss'" @@ -688,7 +688,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("x'A10C'", Literal(Array(0xa1, 0x0c).map(_.toByte))) checkError( exception = parseException("x'A1OC'"), - errorClass = "INVALID_TYPED_LITERAL", + condition = "INVALID_TYPED_LITERAL", sqlState = "42604", parameters = Map( "valueType" -> "\"X\"", @@ -701,7 +701,7 @@ class ExpressionParserSuite extends AnalysisTest { checkError( exception = parseException("GEO '(10,-6)'"), - errorClass = "UNSUPPORTED_TYPED_LITERAL", + condition = "UNSUPPORTED_TYPED_LITERAL", parameters = Map( "unsupportedType" -> "\"GEO\"", "supportedTypes" -> @@ -743,14 +743,14 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("9.e+1BD", Literal(BigDecimal("9.e+1").underlying())) checkError( exception = parseException(".e3"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'.'", "hint" -> "")) // Tiny Int Literal assertEqual("10Y", Literal(10.toByte)) checkError( exception = parseException("1000Y"), - errorClass = "INVALID_NUMERIC_LITERAL_RANGE", + condition = "INVALID_NUMERIC_LITERAL_RANGE", parameters = Map( "rawStrippedQualifier" -> "1000", "minValue" -> Byte.MinValue.toString, @@ -765,7 +765,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("10S", Literal(10.toShort)) checkError( exception = parseException("40000S"), - errorClass = "INVALID_NUMERIC_LITERAL_RANGE", + condition = "INVALID_NUMERIC_LITERAL_RANGE", parameters = Map( "rawStrippedQualifier" -> "40000", "minValue" -> Short.MinValue.toString, @@ -780,7 +780,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("10L", Literal(10L)) checkError( exception = parseException("78732472347982492793712334L"), - errorClass = "INVALID_NUMERIC_LITERAL_RANGE", + condition = "INVALID_NUMERIC_LITERAL_RANGE", parameters = Map( "rawStrippedQualifier" -> "78732472347982492793712334", "minValue" -> Long.MinValue.toString, @@ -795,7 +795,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("10.0D", Literal(10.0D)) checkError( exception = parseException("-1.8E308D"), - errorClass = "INVALID_NUMERIC_LITERAL_RANGE", + condition = "INVALID_NUMERIC_LITERAL_RANGE", parameters = Map( "rawStrippedQualifier" -> "-1.8E308", "minValue" -> BigDecimal(Double.MinValue).toString, @@ -807,7 +807,7 @@ class ExpressionParserSuite extends AnalysisTest { stop = 8)) checkError( exception = parseException("1.8E308D"), - errorClass = "INVALID_NUMERIC_LITERAL_RANGE", + condition = "INVALID_NUMERIC_LITERAL_RANGE", parameters = Map( "rawStrippedQualifier" -> "1.8E308", "minValue" -> BigDecimal(Double.MinValue).toString, @@ -825,7 +825,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("123.08BD", Literal(BigDecimal("123.08").underlying())) checkError( exception = parseException("1.20E-38BD"), - errorClass = "DECIMAL_PRECISION_EXCEEDS_MAX_PRECISION", + 
condition = "DECIMAL_PRECISION_EXCEEDS_MAX_PRECISION", parameters = Map( "precision" -> "40", "maxPrecision" -> "38"), @@ -899,7 +899,7 @@ class ExpressionParserSuite extends AnalysisTest { // when ESCAPED_STRING_LITERALS is enabled. checkError( exception = parseException("'\''"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'''", "hint" -> ": extra input '''")) // The unescape special characters (e.g., "\\t") for 2.0+ don't work @@ -1082,7 +1082,7 @@ class ExpressionParserSuite extends AnalysisTest { // Unknown FROM TO intervals checkError( exception = parseException("interval '10' month to second"), - errorClass = "_LEGACY_ERROR_TEMP_0028", + condition = "_LEGACY_ERROR_TEMP_0028", parameters = Map("from" -> "month", "to" -> "second"), context = ExpectedContext( fragment = "'10' month to second", @@ -1104,7 +1104,7 @@ class ExpressionParserSuite extends AnalysisTest { } else { checkError( exception = parseException(s"interval $intervalStr"), - errorClass = "_LEGACY_ERROR_TEMP_0029", + condition = "_LEGACY_ERROR_TEMP_0029", parameters = Map("literal" -> "interval 3 monThs 4 dayS 22 sEcond 1 millisecond"), context = ExpectedContext( fragment = s"interval $intervalStr", @@ -1120,7 +1120,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("1 - f('o', o(bar))", Literal(1) - $"f".function("o", $"o".function($"bar"))) checkError( exception = parseException("1 - f('o', o(bar)) hello * world"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'*'", "hint" -> "")) } @@ -1142,7 +1142,7 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual(complexName.quotedString, UnresolvedAttribute(Seq("`fo`o", "`ba`r"))) checkError( exception = parseException(complexName.unquotedString), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'.'", "hint" -> "")) // Function identifier contains continuous backticks should be treated correctly. 
@@ -1225,7 +1225,7 @@ class ExpressionParserSuite extends AnalysisTest { Seq("any", "some", "all").foreach { quantifier => checkError( exception = parseException(s"a ilike $quantifier()"), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> "Expected something between '(' and ')'."), context = ExpectedContext( fragment = s"ilike $quantifier()", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala index 83d2557108c57..93afef60a9ddf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala @@ -159,7 +159,7 @@ class ParserUtilsSuite extends SparkFunSuite { exception = intercept[ParseException] { operationNotAllowed(errorMessage, showFuncContext) }, - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> errorMessage)) } @@ -172,7 +172,7 @@ class ParserUtilsSuite extends SparkFunSuite { exception = intercept[ParseException] { checkDuplicateKeys(properties2, createDbContext) }, - errorClass = "DUPLICATE_KEY", + condition = "DUPLICATE_KEY", parameters = Map("keyColumn" -> "`a`")) } @@ -223,7 +223,7 @@ class ParserUtilsSuite extends SparkFunSuite { exception = intercept[ParseException] { validate(f1(emptyContext), message, emptyContext) }, - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> message)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index a6a32e87b7421..6901f6e928c8a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -204,7 +204,7 @@ class PlanParserSuite extends AnalysisTest { |""".stripMargin checkError( exception = parseException(query), - errorClass = "UNCLOSED_BRACKETED_COMMENT", + condition = "UNCLOSED_BRACKETED_COMMENT", parameters = Map.empty) } @@ -222,7 +222,7 @@ class PlanParserSuite extends AnalysisTest { |""".stripMargin checkError( exception = parseException(query), - errorClass = "UNCLOSED_BRACKETED_COMMENT", + condition = "UNCLOSED_BRACKETED_COMMENT", parameters = Map.empty) } @@ -237,7 +237,7 @@ class PlanParserSuite extends AnalysisTest { val sql1 = "EXPLAIN logical SELECT 1" checkError( exception = parseException(sql1), - errorClass = "_LEGACY_ERROR_TEMP_0039", + condition = "_LEGACY_ERROR_TEMP_0039", parameters = Map.empty, context = ExpectedContext( fragment = sql1, @@ -247,7 +247,7 @@ class PlanParserSuite extends AnalysisTest { val sql2 = "EXPLAIN formatted SELECT 1" checkError( exception = parseException(sql2), - errorClass = "_LEGACY_ERROR_TEMP_0039", + condition = "_LEGACY_ERROR_TEMP_0039", parameters = Map.empty, context = ExpectedContext( fragment = sql2, @@ -295,7 +295,7 @@ class PlanParserSuite extends AnalysisTest { val sql = "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1" checkError( exception = parseException(sql), - errorClass = "_LEGACY_ERROR_TEMP_0038", + condition = "_LEGACY_ERROR_TEMP_0038", parameters = Map("duplicateNames" -> "'cte1'"), context = ExpectedContext( fragment = sql, @@ -328,13 +328,13 
@@ class PlanParserSuite extends AnalysisTest { val sql1 = "from a" checkError( exception = parseException(sql1), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "end of input", "hint" -> "")) val sql2 = "from (from a union all from b) c select *" checkError( exception = parseException(sql2), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'union'", "hint" -> "")) } @@ -345,12 +345,12 @@ class PlanParserSuite extends AnalysisTest { val sql1 = "from a select * select * from x where a.s < 10" checkError( exception = parseException(sql1), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'from'", "hint" -> "")) val sql2 = "from a select * from b" checkError( exception = parseException(sql2), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'from'", "hint" -> "")) assertEqual( "from a insert into tbl1 select * insert into tbl2 select * where s < 10", @@ -393,7 +393,7 @@ class PlanParserSuite extends AnalysisTest { val sql1 = s"$baseSql order by a sort by a" checkError( exception = parseException(sql1), - errorClass = "UNSUPPORTED_FEATURE.COMBINATION_QUERY_RESULT_CLAUSES", + condition = "UNSUPPORTED_FEATURE.COMBINATION_QUERY_RESULT_CLAUSES", parameters = Map.empty, context = ExpectedContext( fragment = "order by a sort by a", @@ -403,7 +403,7 @@ class PlanParserSuite extends AnalysisTest { val sql2 = s"$baseSql cluster by a distribute by a" checkError( exception = parseException(sql2), - errorClass = "UNSUPPORTED_FEATURE.COMBINATION_QUERY_RESULT_CLAUSES", + condition = "UNSUPPORTED_FEATURE.COMBINATION_QUERY_RESULT_CLAUSES", parameters = Map.empty, context = ExpectedContext( fragment = "cluster by a distribute by a", @@ -413,7 +413,7 @@ class PlanParserSuite extends AnalysisTest { val sql3 = s"$baseSql order by a cluster by a" checkError( exception = parseException(sql3), - errorClass = "UNSUPPORTED_FEATURE.COMBINATION_QUERY_RESULT_CLAUSES", + condition = "UNSUPPORTED_FEATURE.COMBINATION_QUERY_RESULT_CLAUSES", parameters = Map.empty, context = ExpectedContext( fragment = "order by a cluster by a", @@ -423,7 +423,7 @@ class PlanParserSuite extends AnalysisTest { val sql4 = s"$baseSql order by a distribute by a" checkError( exception = parseException(sql4), - errorClass = "UNSUPPORTED_FEATURE.COMBINATION_QUERY_RESULT_CLAUSES", + condition = "UNSUPPORTED_FEATURE.COMBINATION_QUERY_RESULT_CLAUSES", parameters = Map.empty, context = ExpectedContext( fragment = "order by a distribute by a", @@ -499,7 +499,7 @@ class PlanParserSuite extends AnalysisTest { val sql1 = "SELECT a, b, count(distinct a, distinct b) as c FROM d GROUP BY a, b" checkError( exception = parseException(sql1), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'b'", "hint" -> ": extra input 'b'")) } @@ -595,7 +595,7 @@ class PlanParserSuite extends AnalysisTest { |)""".stripMargin checkError( exception = parseException(sql1), - errorClass = "NOT_ALLOWED_IN_FROM.LATERAL_WITH_PIVOT", + condition = "NOT_ALLOWED_IN_FROM.LATERAL_WITH_PIVOT", parameters = Map.empty, context = ExpectedContext( fragment = fragment1, @@ -617,7 +617,7 @@ class PlanParserSuite extends AnalysisTest { |)""".stripMargin checkError( exception = parseException(sql2), - errorClass = "NOT_ALLOWED_IN_FROM.LATERAL_WITH_UNPIVOT", + condition = "NOT_ALLOWED_IN_FROM.LATERAL_WITH_UNPIVOT", parameters = Map.empty, 
context = ExpectedContext( fragment = fragment2, @@ -647,7 +647,7 @@ class PlanParserSuite extends AnalysisTest { |)""".stripMargin checkError( exception = parseException(sql3), - errorClass = "NOT_ALLOWED_IN_FROM.UNPIVOT_WITH_PIVOT", + condition = "NOT_ALLOWED_IN_FROM.UNPIVOT_WITH_PIVOT", parameters = Map.empty, context = ExpectedContext( fragment = fragment3, @@ -711,7 +711,7 @@ class PlanParserSuite extends AnalysisTest { val sql1 = "select * from a natural cross join b" checkError( exception = parseException(sql1), - errorClass = "INCOMPATIBLE_JOIN_TYPES", + condition = "INCOMPATIBLE_JOIN_TYPES", parameters = Map("joinType1" -> "NATURAL", "joinType2" -> "CROSS"), sqlState = "42613", context = ExpectedContext( @@ -723,7 +723,7 @@ class PlanParserSuite extends AnalysisTest { val sql2 = "select * from a natural join b on a.id = b.id" checkError( exception = parseException(sql2), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'on'", "hint" -> "")) // Test multiple consecutive joins @@ -744,7 +744,7 @@ class PlanParserSuite extends AnalysisTest { val sql3 = "select * from t1 inner join t2 inner join t3 on col3 = col2 on col3 = col1" checkError( exception = parseException(sql3), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'on'", "hint" -> "")) // Parenthesis @@ -834,7 +834,7 @@ class PlanParserSuite extends AnalysisTest { val fragment1 = "tablesample(bucket 4 out of 10 on x)" checkError( exception = parseException(sql1), - errorClass = "_LEGACY_ERROR_TEMP_0015", + condition = "_LEGACY_ERROR_TEMP_0015", parameters = Map("msg" -> "BUCKET x OUT OF y ON colname"), context = ExpectedContext( fragment = fragment1, @@ -845,7 +845,7 @@ class PlanParserSuite extends AnalysisTest { val fragment2 = "tablesample(bucket 11 out of 10)" checkError( exception = parseException(sql2), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> "Sampling fraction (1.1) must be on interval [0, 1]"), context = ExpectedContext( fragment = fragment2, @@ -856,7 +856,7 @@ class PlanParserSuite extends AnalysisTest { val fragment3 = "TABLESAMPLE(300M)" checkError( exception = parseException(sql3), - errorClass = "_LEGACY_ERROR_TEMP_0015", + condition = "_LEGACY_ERROR_TEMP_0015", parameters = Map("msg" -> "byteLengthLiteral"), context = ExpectedContext( fragment = fragment3, @@ -867,7 +867,7 @@ class PlanParserSuite extends AnalysisTest { val fragment4 = "TABLESAMPLE(BUCKET 3 OUT OF 32 ON rand())" checkError( exception = parseException(sql4), - errorClass = "_LEGACY_ERROR_TEMP_0015", + condition = "_LEGACY_ERROR_TEMP_0015", parameters = Map("msg" -> "BUCKET x OUT OF y ON function"), context = ExpectedContext( fragment = fragment4, @@ -925,7 +925,7 @@ class PlanParserSuite extends AnalysisTest { val fragment1 = "default.range(2)" checkError( exception = parseException(sql1), - errorClass = "INVALID_SQL_SYNTAX.INVALID_TABLE_VALUED_FUNC_NAME", + condition = "INVALID_SQL_SYNTAX.INVALID_TABLE_VALUED_FUNC_NAME", parameters = Map("funcName" -> "`default`.`range`"), context = ExpectedContext( fragment = fragment1, @@ -937,7 +937,7 @@ class PlanParserSuite extends AnalysisTest { val fragment2 = "spark_catalog.default.range(2)" checkError( exception = parseException(sql2), - errorClass = "INVALID_SQL_SYNTAX.INVALID_TABLE_VALUED_FUNC_NAME", + condition = "INVALID_SQL_SYNTAX.INVALID_TABLE_VALUED_FUNC_NAME", parameters = Map("funcName" -> 
"`spark_catalog`.`default`.`range`"), context = ExpectedContext( fragment = fragment2, @@ -1047,14 +1047,14 @@ class PlanParserSuite extends AnalysisTest { val sql1 = "SELECT /*+ HINT() */ * FROM t" checkError( exception = parseException(sql1), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "')'", "hint" -> "")) // Disallow space as the delimiter. val sql2 = "SELECT /*+ INDEX(a b c) */ * from default.t" checkError( exception = parseException(sql2), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'b'", "hint" -> "")) assertEqual( @@ -1114,7 +1114,7 @@ class PlanParserSuite extends AnalysisTest { val sql3 = "SELECT /*+ COALESCE(30 + 50) */ * FROM t" checkError( exception = parseException(sql3), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'+'", "hint" -> "")) assertEqual( @@ -1241,13 +1241,13 @@ class PlanParserSuite extends AnalysisTest { val sql1 = "select ltrim(both 'S' from 'SS abc S'" checkError( exception = parseException(sql1), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'from'", "hint" -> "")) // expecting {')' val sql2 = "select rtrim(trailing 'S' from 'SS abc S'" checkError( exception = parseException(sql2), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'from'", "hint" -> "")) // expecting {')' assertTrimPlans( @@ -1361,7 +1361,7 @@ class PlanParserSuite extends AnalysisTest { val sql1 = "CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)" checkError( exception = parseException(sql1), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'INSERT'", "hint" -> "")) // Multi insert query @@ -1371,13 +1371,13 @@ class PlanParserSuite extends AnalysisTest { |INSERT INTO tbl2 SELECT * WHERE jt.id > 4""".stripMargin checkError( exception = parseException(sql2), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'INSERT'", "hint" -> "")) val sql3 = "ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)" checkError( exception = parseException(sql3), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'INSERT'", "hint" -> "")) // Multi insert query @@ -1387,7 +1387,7 @@ class PlanParserSuite extends AnalysisTest { |INSERT INTO tbl2 SELECT * WHERE jt.id > 4""".stripMargin checkError( exception = parseException(sql4), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'INSERT'", "hint" -> "")) } @@ -1395,13 +1395,13 @@ class PlanParserSuite extends AnalysisTest { val sql1 = "SELECT * FROM (INSERT INTO BAR VALUES (2))" checkError( exception = parseException(sql1), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'BAR'", "hint" -> ": missing ')'")) val sql2 = "SELECT * FROM S WHERE C1 IN (INSERT INTO T VALUES (2))" checkError( exception = parseException(sql2), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'IN'", "hint" -> "")) } @@ -1506,7 +1506,7 @@ class PlanParserSuite extends AnalysisTest { val sql1 = "select * from my_tvf(arg1 => table v1)" checkError( exception = parseException(sql1), - errorClass = + condition = "INVALID_SQL_SYNTAX.INVALID_TABLE_FUNCTION_IDENTIFIER_ARGUMENT_MISSING_PARENTHESES", parameters = 
Map("argumentName" -> "`v1`"), context = ExpectedContext( @@ -1627,14 +1627,14 @@ class PlanParserSuite extends AnalysisTest { val sql6 = "select * from my_tvf(arg1 => table(1) partition by col1 with single partition)" checkError( exception = parseException(sql6), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map( "error" -> "'partition'", "hint" -> "")) val sql7 = "select * from my_tvf(arg1 => table(1) order by col1)" checkError( exception = parseException(sql7), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map( "error" -> "'order'", "hint" -> "")) @@ -1643,7 +1643,7 @@ class PlanParserSuite extends AnalysisTest { val sql8 = s"select * from my_tvf(arg1 => $sql8tableArg $sql8partition)" checkError( exception = parseException(sql8), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map( "msg" -> ("The table function call includes a table argument with an invalid " + @@ -1766,7 +1766,7 @@ class PlanParserSuite extends AnalysisTest { |FROM testData""".stripMargin checkError( exception = parseException(sql), - errorClass = "UNSUPPORTED_FEATURE.TRANSFORM_NON_HIVE", + condition = "UNSUPPORTED_FEATURE.TRANSFORM_NON_HIVE", parameters = Map.empty, context = ExpectedContext( fragment = sql, @@ -1824,7 +1824,7 @@ class PlanParserSuite extends AnalysisTest { val fragment = "TIMESTAMP AS OF col" checkError( exception = parseException(sql), - errorClass = "_LEGACY_ERROR_TEMP_0056", + condition = "_LEGACY_ERROR_TEMP_0056", parameters = Map("reason" -> "timestamp expression cannot refer to any columns"), context = ExpectedContext( fragment = fragment, @@ -1919,11 +1919,11 @@ class PlanParserSuite extends AnalysisTest { // Invalid empty name and invalid symbol in a name checkError( exception = parseException(s"SELECT :-"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'-'", "hint" -> "")) checkError( exception = parseException(s"SELECT :"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "end of input", "hint" -> "")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 5fc3ade408bd9..ba634333e06fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression, In, Literal, ScalarSubquery} import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.catalyst.plans.logical.CreateVariable +import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, Project} +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.exceptions.SqlScriptingException class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { @@ -205,13 +207,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT a, b, c FROM T; | SELECT * FROM T; |END lbl_end""".stripMargin - + val exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + } checkError( - exception = intercept[SqlScriptingException] { - 
parseScript(sqlScriptText) - }, - errorClass = "LABELS_MISMATCH", - parameters = Map("beginLabel" -> "lbl_begin", "endLabel" -> "lbl_end")) + exception = exception, + condition = "LABELS_MISMATCH", + parameters = Map("beginLabel" -> toSQLId("lbl_begin"), "endLabel" -> toSQLId("lbl_end"))) + assert(exception.origin.line.contains(2)) } test("compound: endLabel") { @@ -224,13 +227,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT a, b, c FROM T; | SELECT * FROM T; |END lbl""".stripMargin - + val exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + } checkError( - exception = intercept[SqlScriptingException] { - parseScript(sqlScriptText) - }, - errorClass = "END_LABEL_WITHOUT_BEGIN_LABEL", - parameters = Map("endLabel" -> "lbl")) + exception = exception, + condition = "END_LABEL_WITHOUT_BEGIN_LABEL", + parameters = Map("endLabel" -> toSQLId("lbl"))) + assert(exception.origin.line.contains(8)) } test("compound: beginLabel + endLabel with different casing") { @@ -286,12 +290,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT 1; | DECLARE testVariable INTEGER; |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + } checkError( - exception = intercept[SqlScriptingException] { - parseScript(sqlScriptText) - }, - errorClass = "INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING", - parameters = Map("varName" -> "`testVariable`", "lineNumber" -> "4")) + exception = exception, + condition = "INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING", + parameters = Map("varName" -> "`testVariable`")) + assert(exception.origin.line.contains(4)) } test("declare in wrong scope") { @@ -302,12 +308,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | DECLARE testVariable INTEGER; | END IF; |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + } checkError( - exception = intercept[SqlScriptingException] { - parseScript(sqlScriptText) - }, - errorClass = "INVALID_VARIABLE_DECLARATION.NOT_ALLOWED_IN_SCOPE", - parameters = Map("varName" -> "`testVariable`", "lineNumber" -> "4")) + exception = exception, + condition = "INVALID_VARIABLE_DECLARATION.NOT_ALLOWED_IN_SCOPE", + parameters = Map("varName" -> "`testVariable`")) + assert(exception.origin.line.contains(4)) } test("SET VAR statement test") { @@ -666,7 +674,730 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { head.asInstanceOf[SingleStatement].getText == "SELECT 42") assert(whileStmt.label.contains("lbl")) + } + + test("leave compound block") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | LEAVE lbl; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 2) + assert(tree.collection.head.isInstanceOf[SingleStatement]) + assert(tree.collection(1).isInstanceOf[LeaveStatement]) + } + + test("leave while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | SELECT 1; + | LEAVE lbl; + | END WHILE; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[WhileStatement]) + + val whileStmt = tree.collection.head.asInstanceOf[WhileStatement] + assert(whileStmt.condition.isInstanceOf[SingleStatement]) + assert(whileStmt.condition.getText == "1 = 1") + + assert(whileStmt.body.isInstanceOf[CompoundBody]) + assert(whileStmt.body.collection.length == 2) + + 
assert(whileStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(whileStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(whileStmt.body.collection(1).isInstanceOf[LeaveStatement]) + assert(whileStmt.body.collection(1).asInstanceOf[LeaveStatement].label == "lbl") + } + + test("leave repeat loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: REPEAT + | SELECT 1; + | LEAVE lbl; + | UNTIL 1 = 2 + | END REPEAT; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[RepeatStatement]) + + val repeatStmt = tree.collection.head.asInstanceOf[RepeatStatement] + assert(repeatStmt.condition.isInstanceOf[SingleStatement]) + assert(repeatStmt.condition.getText == "1 = 2") + + assert(repeatStmt.body.isInstanceOf[CompoundBody]) + assert(repeatStmt.body.collection.length == 2) + + assert(repeatStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(repeatStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(repeatStmt.body.collection(1).isInstanceOf[LeaveStatement]) + assert(repeatStmt.body.collection(1).asInstanceOf[LeaveStatement].label == "lbl") + } + + test ("iterate compound block - should fail") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | ITERATE lbl; + |END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + }, + condition = "INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND", + parameters = Map("labelName" -> "LBL")) + } + + test("iterate while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | SELECT 1; + | ITERATE lbl; + | END WHILE; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[WhileStatement]) + + val whileStmt = tree.collection.head.asInstanceOf[WhileStatement] + assert(whileStmt.condition.isInstanceOf[SingleStatement]) + assert(whileStmt.condition.getText == "1 = 1") + + assert(whileStmt.body.isInstanceOf[CompoundBody]) + assert(whileStmt.body.collection.length == 2) + + assert(whileStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(whileStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(whileStmt.body.collection(1).isInstanceOf[IterateStatement]) + assert(whileStmt.body.collection(1).asInstanceOf[IterateStatement].label == "lbl") + } + + test("iterate repeat loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: REPEAT + | SELECT 1; + | ITERATE lbl; + | UNTIL 1 = 2 + | END REPEAT; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[RepeatStatement]) + + val repeatStmt = tree.collection.head.asInstanceOf[RepeatStatement] + assert(repeatStmt.condition.isInstanceOf[SingleStatement]) + assert(repeatStmt.condition.getText == "1 = 2") + + assert(repeatStmt.body.isInstanceOf[CompoundBody]) + assert(repeatStmt.body.collection.length == 2) + + assert(repeatStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(repeatStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(repeatStmt.body.collection(1).isInstanceOf[IterateStatement]) + assert(repeatStmt.body.collection(1).asInstanceOf[IterateStatement].label == "lbl") + } + + test("leave with wrong label - should fail") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | LEAVE randomlbl; + 
|END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + }, + condition = "INVALID_LABEL_USAGE.DOES_NOT_EXIST", + parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "LEAVE")) + } + + test("iterate with wrong label - should fail") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | ITERATE randomlbl; + |END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + }, + condition = "INVALID_LABEL_USAGE.DOES_NOT_EXIST", + parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "ITERATE")) + } + + test("leave outer loop from nested while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | lbl2: WHILE 2 = 2 DO + | SELECT 1; + | LEAVE lbl; + | END WHILE; + | END WHILE; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[WhileStatement]) + + val whileStmt = tree.collection.head.asInstanceOf[WhileStatement] + assert(whileStmt.condition.isInstanceOf[SingleStatement]) + assert(whileStmt.condition.getText == "1 = 1") + + assert(whileStmt.body.isInstanceOf[CompoundBody]) + assert(whileStmt.body.collection.length == 1) + + val nestedWhileStmt = whileStmt.body.collection.head.asInstanceOf[WhileStatement] + assert(nestedWhileStmt.condition.isInstanceOf[SingleStatement]) + assert(nestedWhileStmt.condition.getText == "2 = 2") + + assert(nestedWhileStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(nestedWhileStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(nestedWhileStmt.body.collection(1).isInstanceOf[LeaveStatement]) + assert(nestedWhileStmt.body.collection(1).asInstanceOf[LeaveStatement].label == "lbl") + } + + test("leave outer loop from nested repeat loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: REPEAT + | lbl2: REPEAT + | SELECT 1; + | LEAVE lbl; + | UNTIL 2 = 2 + | END REPEAT; + | UNTIL 1 = 1 + | END REPEAT; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[RepeatStatement]) + + val repeatStmt = tree.collection.head.asInstanceOf[RepeatStatement] + assert(repeatStmt.condition.isInstanceOf[SingleStatement]) + assert(repeatStmt.condition.getText == "1 = 1") + + assert(repeatStmt.body.isInstanceOf[CompoundBody]) + assert(repeatStmt.body.collection.length == 1) + + val nestedRepeatStmt = repeatStmt.body.collection.head.asInstanceOf[RepeatStatement] + assert(nestedRepeatStmt.condition.isInstanceOf[SingleStatement]) + assert(nestedRepeatStmt.condition.getText == "2 = 2") + + assert(nestedRepeatStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert( + nestedRepeatStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(nestedRepeatStmt.body.collection(1).isInstanceOf[LeaveStatement]) + assert(nestedRepeatStmt.body.collection(1).asInstanceOf[LeaveStatement].label == "lbl") + } + + test("iterate outer loop from nested while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | lbl2: WHILE 2 = 2 DO + | SELECT 1; + | ITERATE lbl; + | END WHILE; + | END WHILE; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[WhileStatement]) + + val whileStmt = tree.collection.head.asInstanceOf[WhileStatement] + assert(whileStmt.condition.isInstanceOf[SingleStatement]) + 
assert(whileStmt.condition.getText == "1 = 1") + + assert(whileStmt.body.isInstanceOf[CompoundBody]) + assert(whileStmt.body.collection.length == 1) + + val nestedWhileStmt = whileStmt.body.collection.head.asInstanceOf[WhileStatement] + assert(nestedWhileStmt.condition.isInstanceOf[SingleStatement]) + assert(nestedWhileStmt.condition.getText == "2 = 2") + + assert(nestedWhileStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(nestedWhileStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(nestedWhileStmt.body.collection(1).isInstanceOf[IterateStatement]) + assert(nestedWhileStmt.body.collection(1).asInstanceOf[IterateStatement].label == "lbl") + } + + test("iterate outer loop from nested repeat loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: REPEAT + | lbl2: REPEAT + | SELECT 1; + | ITERATE lbl; + | UNTIL 2 = 2 + | END REPEAT; + | UNTIL 1 = 1 + | END REPEAT; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[RepeatStatement]) + + val repeatStmt = tree.collection.head.asInstanceOf[RepeatStatement] + assert(repeatStmt.condition.isInstanceOf[SingleStatement]) + assert(repeatStmt.condition.getText == "1 = 1") + + assert(repeatStmt.body.isInstanceOf[CompoundBody]) + assert(repeatStmt.body.collection.length == 1) + + val nestedRepeatStmt = repeatStmt.body.collection.head.asInstanceOf[RepeatStatement] + assert(nestedRepeatStmt.condition.isInstanceOf[SingleStatement]) + assert(nestedRepeatStmt.condition.getText == "2 = 2") + + assert(nestedRepeatStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert( + nestedRepeatStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(nestedRepeatStmt.body.collection(1).isInstanceOf[IterateStatement]) + assert(nestedRepeatStmt.body.collection(1).asInstanceOf[IterateStatement].label == "lbl") + } + + test("repeat") { + val sqlScriptText = + """BEGIN + |lbl: REPEAT + | SELECT 1; + | UNTIL 1 = 1 + |END REPEAT lbl; + |END + """.stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[RepeatStatement]) + + val repeatStmt = tree.collection.head.asInstanceOf[RepeatStatement] + assert(repeatStmt.condition.isInstanceOf[SingleStatement]) + assert(repeatStmt.condition.getText == "1 = 1") + + assert(repeatStmt.body.isInstanceOf[CompoundBody]) + assert(repeatStmt.body.collection.length == 1) + assert(repeatStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(repeatStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(repeatStmt.label.contains("lbl")) + } + + test("repeat with complex condition") { + val sqlScriptText = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |REPEAT + | SELECT 42; + |UNTIL + | (SELECT COUNT(*) < 2 FROM t) + |END REPEAT; + |END + |""".stripMargin + + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 2) + assert(tree.collection(1).isInstanceOf[RepeatStatement]) + + val repeatStmt = tree.collection(1).asInstanceOf[RepeatStatement] + assert(repeatStmt.condition.isInstanceOf[SingleStatement]) + assert(repeatStmt.condition.getText == "(SELECT COUNT(*) < 2 FROM t)") + + assert(repeatStmt.body.isInstanceOf[CompoundBody]) + assert(repeatStmt.body.collection.length == 1) + assert(repeatStmt.body.collection.head.isInstanceOf[SingleStatement]) + 
assert(repeatStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 42") + } + + test("repeat with if else block") { + val sqlScriptText = + """BEGIN + |lbl: REPEAT + | IF 1 = 1 THEN + | SELECT 1; + | ELSE + | SELECT 2; + | END IF; + |UNTIL + | 1 = 1 + |END REPEAT lbl; + |END + """.stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[RepeatStatement]) + + val whileStmt = tree.collection.head.asInstanceOf[RepeatStatement] + assert(whileStmt.condition.isInstanceOf[SingleStatement]) + assert(whileStmt.condition.getText == "1 = 1") + + assert(whileStmt.body.isInstanceOf[CompoundBody]) + assert(whileStmt.body.collection.length == 1) + assert(whileStmt.body.collection.head.isInstanceOf[IfElseStatement]) + val ifStmt = whileStmt.body.collection.head.asInstanceOf[IfElseStatement] + + assert(ifStmt.conditions.length == 1) + assert(ifStmt.conditionalBodies.length == 1) + assert(ifStmt.elseBody.isDefined) + + assert(ifStmt.conditions.head.isInstanceOf[SingleStatement]) + assert(ifStmt.conditions.head.getText == "1 = 1") + + assert(ifStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement]) + assert(ifStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 1") + + assert(ifStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement]) + assert(ifStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 2") + + assert(whileStmt.label.contains("lbl")) + } + + test("nested repeat") { + val sqlScriptText = + """BEGIN + |lbl: REPEAT + | REPEAT + | SELECT 42; + | UNTIL + | 2 = 2 + | END REPEAT; + |UNTIL + | 1 = 1 + |END REPEAT lbl; + |END + """.stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[RepeatStatement]) + + val whileStmt = tree.collection.head.asInstanceOf[RepeatStatement] + assert(whileStmt.condition.isInstanceOf[SingleStatement]) + assert(whileStmt.condition.getText == "1 = 1") + + assert(whileStmt.body.isInstanceOf[CompoundBody]) + assert(whileStmt.body.collection.length == 1) + assert(whileStmt.body.collection.head.isInstanceOf[RepeatStatement]) + val nestedWhileStmt = whileStmt.body.collection.head.asInstanceOf[RepeatStatement] + + assert(nestedWhileStmt.condition.isInstanceOf[SingleStatement]) + assert(nestedWhileStmt.condition.getText == "2 = 2") + + assert(nestedWhileStmt.body.isInstanceOf[CompoundBody]) + assert(nestedWhileStmt.body.collection.length == 1) + assert(nestedWhileStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(nestedWhileStmt.body.collection. 
+ head.asInstanceOf[SingleStatement].getText == "SELECT 42") + + assert(whileStmt.label.contains("lbl")) + + } + + test("searched case statement") { + val sqlScriptText = + """ + |BEGIN + | CASE + | WHEN 1 = 1 THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.conditions.length == 1) + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditions.head.getText == "1 = 1") + } + + test("searched case statement - multi when") { + val sqlScriptText = + """ + |BEGIN + | CASE + | WHEN 1 IN (1,2,3) THEN + | SELECT 1; + | WHEN (SELECT * FROM t) THEN + | SELECT * FROM b; + | WHEN 1 = 1 THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.conditions.length == 3) + assert(caseStmt.conditionalBodies.length == 3) + assert(caseStmt.elseBody.isEmpty) + + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditions.head.getText == "1 IN (1,2,3)") + + assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 1") + + assert(caseStmt.conditions(1).isInstanceOf[SingleStatement]) + assert(caseStmt.conditions(1).getText == "(SELECT * FROM t)") + + assert(caseStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT * FROM b") + + assert(caseStmt.conditions(2).isInstanceOf[SingleStatement]) + assert(caseStmt.conditions(2).getText == "1 = 1") + + assert(caseStmt.conditionalBodies(2).collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditionalBodies(2).collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 42") + } + + test("searched case statement with else") { + val sqlScriptText = + """ + |BEGIN + | CASE + | WHEN 1 = 1 THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.elseBody.isDefined) + assert(caseStmt.conditions.length == 1) + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditions.head.getText == "1 = 1") + + assert(caseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 43") + } + test("searched case statement nested") { + val sqlScriptText = + """ + |BEGIN + | CASE + | WHEN 1 = 1 THEN + | CASE + | WHEN 2 = 1 THEN + | SELECT 41; + | ELSE + | SELECT 42; + | END CASE; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.conditions.length == 1) + assert(caseStmt.conditionalBodies.length == 1) + 
assert(caseStmt.elseBody.isEmpty) + + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditions.head.getText == "1 = 1") + + assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement]) + val nestedCaseStmt = + caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement] + + assert(nestedCaseStmt.conditions.length == 1) + assert(nestedCaseStmt.conditionalBodies.length == 1) + assert(nestedCaseStmt.elseBody.isDefined) + + assert(nestedCaseStmt.conditions.head.isInstanceOf[SingleStatement]) + assert(nestedCaseStmt.conditions.head.getText == "2 = 1") + + assert(nestedCaseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement]) + assert(nestedCaseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 41") + + assert(nestedCaseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement]) + assert(nestedCaseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 42") + } + + test("simple case statement") { + val sqlScriptText = + """ + |BEGIN + | CASE 1 + | WHEN 1 THEN + | SELECT 1; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.conditions.length == 1) + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1)) + } + + + test("simple case statement - multi when") { + val sqlScriptText = + """ + |BEGIN + | CASE 1 + | WHEN 1 THEN + | SELECT 1; + | WHEN (SELECT 2) THEN + | SELECT * FROM b; + | WHEN 3 IN (1,2,3) THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.conditions.length == 3) + assert(caseStmt.conditionalBodies.length == 3) + assert(caseStmt.elseBody.isEmpty) + + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1)) + + assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 1") + + assert(caseStmt.conditions(1).isInstanceOf[SingleStatement]) + checkSimpleCaseStatementCondition( + caseStmt.conditions(1), _ == Literal(1), _.isInstanceOf[ScalarSubquery]) + + assert(caseStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT * FROM b") + + assert(caseStmt.conditions(2).isInstanceOf[SingleStatement]) + checkSimpleCaseStatementCondition( + caseStmt.conditions(2), _ == Literal(1), _.isInstanceOf[In]) + + assert(caseStmt.conditionalBodies(2).collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditionalBodies(2).collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 42") + } + + test("simple case statement with else") { + val sqlScriptText = + """ + |BEGIN + | CASE 1 + | WHEN 1 THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + 
assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.elseBody.isDefined) + assert(caseStmt.conditions.length == 1) + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1)) + + assert(caseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 43") + } + + test("simple case statement nested") { + val sqlScriptText = + """ + |BEGIN + | CASE (SELECT 1) + | WHEN 1 THEN + | CASE 2 + | WHEN 2 THEN + | SELECT 41; + | ELSE + | SELECT 42; + | END CASE; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.conditions.length == 1) + assert(caseStmt.conditionalBodies.length == 1) + assert(caseStmt.elseBody.isEmpty) + + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + checkSimpleCaseStatementCondition( + caseStmt.conditions.head, _.isInstanceOf[ScalarSubquery], _ == Literal(1)) + + assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement]) + val nestedCaseStmt = + caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement] + + assert(nestedCaseStmt.conditions.length == 1) + assert(nestedCaseStmt.conditionalBodies.length == 1) + assert(nestedCaseStmt.elseBody.isDefined) + + assert(nestedCaseStmt.conditions.head.isInstanceOf[SingleStatement]) + checkSimpleCaseStatementCondition( + nestedCaseStmt.conditions.head, _ == Literal(2), _ == Literal(2)) + + assert(nestedCaseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement]) + assert(nestedCaseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 41") + + assert(nestedCaseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement]) + assert(nestedCaseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 42") } // Helper methods @@ -677,4 +1408,17 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { .replace("END", "") .trim } + + private def checkSimpleCaseStatementCondition( + conditionStatement: SingleStatement, + predicateLeft: Expression => Boolean, + predicateRight: Expression => Boolean): Unit = { + assert(conditionStatement.parsedPlan.isInstanceOf[Project]) + val project = conditionStatement.parsedPlan.asInstanceOf[Project] + assert(project.projectList.head.isInstanceOf[Alias]) + assert(project.projectList.head.asInstanceOf[Alias].child.isInstanceOf[EqualTo]) + val equalTo = project.projectList.head.asInstanceOf[Alias].child.asInstanceOf[EqualTo] + assert(predicateLeft(equalTo.left)) + assert(predicateRight(equalTo.right)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 62557ead1d2ee..0f32922728814 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -296,10 +296,10 @@ class TableIdentifierParserSuite extends 
SQLKeywordUtils { "t:" -> ("PARSE_SYNTAX_ERROR", Map("error" -> "':'", "hint" -> ": extra input ':'")), "${some.var.x}" -> ("PARSE_SYNTAX_ERROR", Map("error" -> "'$'", "hint" -> "")), "tab:1" -> ("PARSE_SYNTAX_ERROR", Map("error" -> "':'", "hint" -> "")) - ).foreach { case (identifier, (errorClass, parameters)) => + ).foreach { case (identifier, (condition, parameters)) => checkError( exception = intercept[ParseException](parseTableIdentifier(identifier)), - errorClass = errorClass, + condition = condition, parameters = parameters) } } @@ -318,7 +318,7 @@ class TableIdentifierParserSuite extends SQLKeywordUtils { reservedKeywordsInAnsiMode.foreach { keyword => checkError( exception = intercept[ParseException](parseTableIdentifier(keyword)), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> s"'$keyword'", "hint" -> "")) assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) @@ -374,7 +374,7 @@ class TableIdentifierParserSuite extends SQLKeywordUtils { assert(complexName === parseTableIdentifier(complexName.quotedString)) checkError( exception = intercept[ParseException](parseTableIdentifier(complexName.unquotedString)), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'b'", "hint" -> "")) // Table identifier contains continuous backticks should be treated correctly. val complexName2 = TableIdentifier("x``y", Some("d``b")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala index a56ab8616df0f..74fb5a44ab0bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala @@ -75,26 +75,26 @@ class TableSchemaParserSuite extends SparkFunSuite { checkError( exception = parseException(""), - errorClass = "PARSE_EMPTY_STATEMENT") + condition = "PARSE_EMPTY_STATEMENT") checkError( exception = parseException("a"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "end of input", "hint" -> "")) checkError( exception = parseException("a INT b long"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'b'", "hint" -> "")) checkError( exception = parseException("a INT,, b long"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "','", "hint" -> "")) checkError( exception = parseException("a INT, b long,,"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "','", "hint" -> "")) checkError( exception = parseException("a INT, b long, c int,"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "end of input", "hint" -> "")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala index 3012ef6f1544d..3f59f8de95429 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala @@ -31,8 
+31,8 @@ class UnpivotParserSuite extends AnalysisTest { comparePlans(parsePlan(sqlCommand), plan, checkAnalysis = false) } - private def intercept(sqlCommand: String, errorClass: Option[String], messages: String*): Unit = - interceptParseException(parsePlan)(sqlCommand, messages: _*)(errorClass) + private def intercept(sqlCommand: String, condition: Option[String], messages: String*): Unit = + interceptParseException(parsePlan)(sqlCommand, messages: _*)(condition) test("unpivot - single value") { assertEqual( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/JoinTypesTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/JoinTypesTest.scala index 886b043ad79e6..7fa1935ccb058 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/JoinTypesTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/JoinTypesTest.scala @@ -68,7 +68,7 @@ class JoinTypesTest extends SparkFunSuite { exception = intercept[AnalysisException]( JoinType(joinType) ), - errorClass = "UNSUPPORTED_JOIN_TYPE", + condition = "UNSUPPORTED_JOIN_TYPE", sqlState = "0A000", parameters = Map( "typ" -> joinType, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala index 55afbc3acb096..3a739ccbecb64 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala @@ -41,7 +41,7 @@ class InternalOutputModesSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { InternalOutputModes(outputMode) }, - errorClass = "STREAMING_OUTPUT_MODE.INVALID", + condition = "STREAMING_OUTPUT_MODE.INVALID", parameters = Map("outputMode" -> outputMode)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala index 1d3fb835f5a77..0e872dcdb6262 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala @@ -43,7 +43,7 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper { builder.put(1, null) // null value is OK checkError( exception = intercept[SparkRuntimeException](builder.put(null, 1)), - errorClass = "NULL_MAP_KEY", + condition = "NULL_MAP_KEY", parameters = Map.empty ) } @@ -53,7 +53,7 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper { builder.put(1, 1) checkError( exception = intercept[SparkRuntimeException](builder.put(1, 2)), - errorClass = "DUPLICATED_MAP_KEY", + condition = "DUPLICATED_MAP_KEY", parameters = Map( "key" -> "1", "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") @@ -65,7 +65,7 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper { builderDouble.put(-0.0, 1) checkError( exception = intercept[SparkRuntimeException](builderDouble.put(0.0, 2)), - errorClass = "DUPLICATED_MAP_KEY", + condition = "DUPLICATED_MAP_KEY", parameters = Map( "key" -> "0.0", "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") @@ -110,7 +110,7 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper { val arr = Array(1.toByte) checkError( exception 
= intercept[SparkRuntimeException](builder.put(arr, 3)), - errorClass = "DUPLICATED_MAP_KEY", + condition = "DUPLICATED_MAP_KEY", parameters = Map( "key" -> arr.toString, "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") @@ -146,7 +146,7 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper { // By default duplicated map key fails the query. checkError( exception = intercept[SparkRuntimeException](builder.put(unsafeRow, 3)), - errorClass = "DUPLICATED_MAP_KEY", + condition = "DUPLICATED_MAP_KEY", parameters = Map( "key" -> "[0,1]", "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") @@ -180,7 +180,7 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper { // By default duplicated map key fails the query. checkError( exception = intercept[SparkRuntimeException](builder.put(unsafeArray, 3)), - errorClass = "DUPLICATED_MAP_KEY", + condition = "DUPLICATED_MAP_KEY", parameters = Map( "key" -> unsafeArray.toString, "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala index d55e672079484..632109a0cc8d0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala @@ -58,7 +58,7 @@ class ArrayDataIndexedSeqSuite extends SparkFunSuite { exception = intercept[SparkException] { seq(index) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map( "message" -> s"Index $index must be between 0 and the length of the ArrayData.")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelperSuite.scala index 034010f5825b8..095dc3869571d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelperSuite.scala @@ -42,7 +42,7 @@ class DateTimeFormatterHelperSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { convertIncompatiblePattern(s"yyyy-MM-dd $l G") }, - errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.DATETIME_WEEK_BASED_PATTERN", + condition = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.DATETIME_WEEK_BASED_PATTERN", parameters = Map("c" -> l.toString)) } unsupportedLetters.foreach { l => @@ -50,7 +50,7 @@ class DateTimeFormatterHelperSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { convertIncompatiblePattern(s"yyyy-MM-dd $l G") }, - errorClass = "INVALID_DATETIME_PATTERN.ILLEGAL_CHARACTER", + condition = "INVALID_DATETIME_PATTERN.ILLEGAL_CHARACTER", parameters = Map( "c" -> l.toString, "pattern" -> s"yyyy-MM-dd $l G")) @@ -60,7 +60,7 @@ class DateTimeFormatterHelperSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { DateTimeFormatterHelper.convertIncompatiblePattern(s"$l", isParsing = true) }, - errorClass = "INVALID_DATETIME_PATTERN.ILLEGAL_CHARACTER", + condition = "INVALID_DATETIME_PATTERN.ILLEGAL_CHARACTER", parameters = Map( "c" -> l.toString, "pattern" -> s"$l")) @@ -70,17 +70,27 @@ class DateTimeFormatterHelperSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { 
convertIncompatiblePattern(s"yyyy-MM-dd $style") }, - errorClass = "INVALID_DATETIME_PATTERN.LENGTH", + condition = "INVALID_DATETIME_PATTERN.LENGTH", parameters = Map("pattern" -> style)) checkError( exception = intercept[SparkIllegalArgumentException] { convertIncompatiblePattern(s"yyyy-MM-dd $style${style.head}") }, - errorClass = "INVALID_DATETIME_PATTERN.LENGTH", + condition = "INVALID_DATETIME_PATTERN.LENGTH", parameters = Map("pattern" -> style)) } assert(convertIncompatiblePattern("yyyy-MM-dd EEEE") === "uuuu-MM-dd EEEE") assert(convertIncompatiblePattern("yyyy-MM-dd'e'HH:mm:ss") === "uuuu-MM-dd'e'HH:mm:ss") assert(convertIncompatiblePattern("yyyy-MM-dd'T'") === "uuuu-MM-dd'T'") } + + test("SPARK-49583: invalid var length second fraction") { + val pattern = "\nSSSS\r" + checkError( + exception = intercept[SparkIllegalArgumentException] { + createBuilderWithVarLengthSecondFraction(pattern) + }, + condition = "INVALID_DATETIME_PATTERN.SECONDS_FRACTION", + parameters = Map("pattern" -> pattern)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 8d8669aece894..96aaf13052b02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -542,7 +542,7 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { checkError( exception = intercept[SparkIllegalArgumentException]( dateAddInterval(input, new CalendarInterval(36, 47, 1))), - errorClass = "_LEGACY_ERROR_TEMP_2000", + condition = "_LEGACY_ERROR_TEMP_2000", parameters = Map( "message" -> "Cannot add hours, minutes or seconds, milliseconds, microseconds to a date", "ansiConfig" -> "\"spark.sql.ansi.enabled\"")) @@ -896,13 +896,13 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { exception = intercept[SparkIllegalArgumentException] { getDayOfWeekFromString(UTF8String.fromString("xx")) }, - errorClass = "_LEGACY_ERROR_TEMP_3209", + condition = "ILLEGAL_DAY_OF_WEEK", parameters = Map("string" -> "xx")) checkError( exception = intercept[SparkIllegalArgumentException] { getDayOfWeekFromString(UTF8String.fromString("\"quote")) }, - errorClass = "_LEGACY_ERROR_TEMP_3209", + condition = "ILLEGAL_DAY_OF_WEEK", parameters = Map("string" -> "\"quote")) } @@ -1043,7 +1043,7 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { exception = intercept[SparkIllegalArgumentException] { timestampAdd("SECS", 1, date(1969, 1, 1, 0, 0, 0, 1, getZoneId("UTC")), getZoneId("UTC")) }, - errorClass = "INVALID_PARAMETER_VALUE.DATETIME_UNIT", + condition = "INVALID_PARAMETER_VALUE.DATETIME_UNIT", parameters = Map( "functionName" -> "`TIMESTAMPADD`", "parameter" -> "`unit`", @@ -1102,7 +1102,7 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { date(2022, 1, 1, 0, 0, 0, 1, getZoneId("UTC")), getZoneId("UTC")) }, - errorClass = "INVALID_PARAMETER_VALUE.DATETIME_UNIT", + condition = "INVALID_PARAMETER_VALUE.DATETIME_UNIT", parameters = Map("functionName" -> "`TIMESTAMPDIFF`", "parameter" -> "`unit`", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala index 001ae2728d10f..700dfe30a2389 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala @@ -44,7 +44,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { stringToInterval(UTF8String.fromString(input)) }, - errorClass = "INVALID_INTERVAL_FORMAT.INPUT_IS_NULL", + condition = "INVALID_INTERVAL_FORMAT.INPUT_IS_NULL", parameters = Map( "input" -> Option(input).map(_.toString).getOrElse("null"))) assert(safeStringToInterval(UTF8String.fromString(input)) === null) @@ -55,7 +55,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { stringToInterval(UTF8String.fromString(input)) }, - errorClass = "INVALID_INTERVAL_FORMAT.INPUT_IS_EMPTY", + condition = "INVALID_INTERVAL_FORMAT.INPUT_IS_EMPTY", parameters = Map( "input" -> Option(input).map(_.toString).getOrElse("null"))) assert(safeStringToInterval(UTF8String.fromString(input)) === null) @@ -66,7 +66,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { stringToInterval(UTF8String.fromString(input)) }, - errorClass = "INVALID_INTERVAL_FORMAT.INVALID_PREFIX", + condition = "INVALID_INTERVAL_FORMAT.INVALID_PREFIX", parameters = Map( "input" -> Option(input).map(_.toString).getOrElse("null"), "prefix" -> prefix)) @@ -78,7 +78,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { stringToInterval(UTF8String.fromString(input)) }, - errorClass = "INVALID_INTERVAL_FORMAT.UNRECOGNIZED_NUMBER", + condition = "INVALID_INTERVAL_FORMAT.UNRECOGNIZED_NUMBER", parameters = Map( "input" -> Option(input).map(_.toString).getOrElse("null"), "number" -> number)) @@ -90,7 +90,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { stringToInterval(UTF8String.fromString(input)) }, - errorClass = "INVALID_INTERVAL_FORMAT.ARITHMETIC_EXCEPTION", + condition = "INVALID_INTERVAL_FORMAT.ARITHMETIC_EXCEPTION", parameters = Map( "input" -> Option(input).map(_.toString).getOrElse("null"))) assert(safeStringToInterval(UTF8String.fromString(input)) === null) @@ -101,7 +101,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { stringToInterval(UTF8String.fromString(input)) }, - errorClass = "INVALID_INTERVAL_FORMAT.INVALID_VALUE", + condition = "INVALID_INTERVAL_FORMAT.INVALID_VALUE", parameters = Map( "input" -> Option(input).map(_.toString).getOrElse("null"), "value" -> value)) @@ -113,7 +113,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { stringToInterval(UTF8String.fromString(input)) }, - errorClass = "INVALID_INTERVAL_FORMAT.INVALID_PRECISION", + condition = "INVALID_INTERVAL_FORMAT.INVALID_PRECISION", parameters = Map( "input" -> Option(input).map(_.toString).getOrElse("null"), "value" -> value)) @@ -125,7 +125,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { stringToInterval(UTF8String.fromString(input)) }, - errorClass = "INVALID_INTERVAL_FORMAT.INVALID_FRACTION", + condition = "INVALID_INTERVAL_FORMAT.INVALID_FRACTION", parameters = Map( "input" -> Option(input).map(_.toString).getOrElse("null"), "unit" -> unit)) @@ 
-137,7 +137,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { stringToInterval(UTF8String.fromString(input)) }, - errorClass = "INVALID_INTERVAL_FORMAT.INVALID_UNIT", + condition = "INVALID_INTERVAL_FORMAT.INVALID_UNIT", parameters = Map( "input" -> Option(input).map(_.toString).getOrElse("null"), "unit" -> unit)) @@ -149,7 +149,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { stringToInterval(UTF8String.fromString(input)) }, - errorClass = "INVALID_INTERVAL_FORMAT.MISSING_NUMBER", + condition = "INVALID_INTERVAL_FORMAT.MISSING_NUMBER", parameters = Map( "input" -> Option(input).map(_.toString).getOrElse("null"), "word" -> word)) @@ -161,7 +161,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { stringToInterval(UTF8String.fromString(input)) }, - errorClass = "INVALID_INTERVAL_FORMAT.MISSING_UNIT", + condition = "INVALID_INTERVAL_FORMAT.MISSING_UNIT", parameters = Map( "input" -> Option(input).map(_.toString).getOrElse("null"), "word" -> word)) @@ -173,7 +173,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkIllegalArgumentException] { stringToInterval(UTF8String.fromString(input)) }, - errorClass = "INVALID_INTERVAL_FORMAT.UNKNOWN_PARSING_ERROR", + condition = "INVALID_INTERVAL_FORMAT.UNKNOWN_PARSING_ERROR", parameters = Map( "input" -> Option(input).map(_.toString).getOrElse("null"), "word" -> word)) @@ -295,7 +295,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { assert(fromYearMonthString("99-10") === new CalendarInterval(99 * 12 + 10, 0, 0L)) assert(fromYearMonthString("+99-10") === new CalendarInterval(99 * 12 + 10, 0, 0L)) assert(fromYearMonthString("-8-10") === new CalendarInterval(-8 * 12 - 10, 0, 0L)) - failFuncWithInvalidInput("99-15", "month 15 outside range", fromYearMonthString) + failFuncWithInvalidInput("99-15", "year-month", fromYearMonthString) failFuncWithInvalidInput("9a9-15", "Interval string does not match year-month format", fromYearMonthString) @@ -314,12 +314,12 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { val e1 = intercept[IllegalArgumentException]{ assert(fromYearMonthString("178956970-8") == new CalendarInterval(Int.MinValue, 0, 0)) }.getMessage - assert(e1.contains("integer overflow")) + assert(e1.contains("year-month")) assert(fromYearMonthString("-178956970-8") == new CalendarInterval(Int.MinValue, 0, 0)) val e2 = intercept[IllegalArgumentException]{ assert(fromYearMonthString("-178956970-9") == new CalendarInterval(Int.MinValue, 0, 0)) }.getMessage - assert(e2.contains("integer overflow")) + assert(e2.contains("year-month")) } test("from day-time string - legacy") { @@ -338,6 +338,29 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { 12 * MICROS_PER_MINUTE + millisToMicros(888))) assert(fromDayTimeString("-3 0:0:0") === new CalendarInterval(0, -3, 0L)) + checkError( + exception = intercept[SparkIllegalArgumentException] { + fromDayTimeString("5 30:12:20") + }, + parameters = Map( + "msg" -> "requirement failed: hour 30 outside range [0, 23]", + "input" -> "5 30:12:20"), + condition = "INVALID_INTERVAL_FORMAT.DAY_TIME_PARSING", + sqlState = Some("22006") + ) + + checkError( + exception = intercept[SparkIllegalArgumentException] { + fromDayTimeString("5 12:40:30.999999999", 0, 0) + }, + parameters = Map( + "from" -> "day", + "to" -> "day", + 
"input" -> "5 12:40:30.999999999"), + condition = "INVALID_INTERVAL_FORMAT.UNSUPPORTED_FROM_TO_EXPRESSION", + sqlState = Some("22006") + ) + failFuncWithInvalidInput("5 30:12:20", "hour 30 outside range", fromDayTimeString) failFuncWithInvalidInput("5 30-12", "must match day-time format", fromDayTimeString) } @@ -379,6 +402,17 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { assert(negate(new CalendarInterval(1, 2, 3)) === new CalendarInterval(-1, -2, -3)) } + test("parsing second_nano string") { + checkError( + exception = intercept[SparkIllegalArgumentException] { + toDTInterval("12", "33.33.33", 1) + }, + condition = "INVALID_INTERVAL_FORMAT.SECOND_NANO_FORMAT", + parameters = Map("input" -> "33.33.33"), + sqlState = Some("22006") + ) + } + test("subtract one interval by another") { val input1 = new CalendarInterval(3, 1, 1 * MICROS_PER_HOUR) val input2 = new CalendarInterval(2, 4, 100 * MICROS_PER_HOUR) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimestampFormatterSuite.scala index 6223f9aadb593..558d7eda78b4a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimestampFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimestampFormatterSuite.scala @@ -553,7 +553,7 @@ class TimestampFormatterSuite extends DatetimeFormatterSuite { exception = intercept[SparkException] { formatter.parseWithoutTimeZone(invalidTimestampStr, allowTimeZone = false) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map( "message" -> ("Cannot parse field value '2021-13-01T25:61:61' for pattern " + "'yyyy-MM-dd HH:mm:ss' as the target spark data type \"TIMESTAMP_NTZ\".")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala index e20dfd4f60512..aca6931a0688d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala @@ -362,7 +362,7 @@ class CatalogSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { catalog.alterTable(testIdent, TableChange.addColumn(Array("data", "ts"), TimestampType)) }, - errorClass = "_LEGACY_ERROR_TEMP_3229", + condition = "_LEGACY_ERROR_TEMP_3229", parameters = Map("name" -> "data")) // the table has not changed @@ -381,7 +381,7 @@ class CatalogSuite extends SparkFunSuite { catalog.alterTable(testIdent, TableChange.addColumn(Array("missing_col", "new_field"), StringType)) }, - errorClass = "_LEGACY_ERROR_TEMP_3227", + condition = "_LEGACY_ERROR_TEMP_3227", parameters = Map("fieldName" -> "missing_col")) } @@ -427,7 +427,7 @@ class CatalogSuite extends SparkFunSuite { catalog.alterTable(testIdent, TableChange.updateColumnType(Array("missing_col"), LongType)) }, - errorClass = "_LEGACY_ERROR_TEMP_3227", + condition = "_LEGACY_ERROR_TEMP_3227", parameters = Map("fieldName" -> "missing_col")) } @@ -478,7 +478,7 @@ class CatalogSuite extends SparkFunSuite { catalog.alterTable(testIdent, TableChange.updateColumnComment(Array("missing_col"), "comment")) }, - errorClass = "_LEGACY_ERROR_TEMP_3227", + condition = "_LEGACY_ERROR_TEMP_3227", parameters = Map("fieldName" -> "missing_col")) } @@ -546,7 +546,7 @@ class CatalogSuite extends SparkFunSuite { 
catalog.alterTable(testIdent, TableChange.renameColumn(Array("missing_col"), "new_name")) }, - errorClass = "_LEGACY_ERROR_TEMP_3227", + condition = "_LEGACY_ERROR_TEMP_3227", parameters = Map("fieldName" -> "missing_col")) } @@ -614,7 +614,7 @@ class CatalogSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col"), false)) }, - errorClass = "_LEGACY_ERROR_TEMP_3227", + condition = "_LEGACY_ERROR_TEMP_3227", parameters = Map("fieldName" -> "missing_col")) // with if exists it should pass @@ -636,7 +636,7 @@ class CatalogSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"), false)) }, - errorClass = "_LEGACY_ERROR_TEMP_3227", + condition = "_LEGACY_ERROR_TEMP_3227", parameters = Map("fieldName" -> "z")) // with if exists it should pass diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala index 8d8d2317f0986..411a88b8765f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala @@ -24,10 +24,13 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, NoSuchNamespaceException} import org.apache.spark.sql.connector.catalog.functions.UnboundFunction +import org.apache.spark.sql.connector.catalog.procedures.UnboundProcedure -class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog { +class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog with ProcedureCatalog { protected val functions: util.Map[Identifier, UnboundFunction] = new ConcurrentHashMap[Identifier, UnboundFunction]() + protected val procedures: util.Map[Identifier, UnboundProcedure] = + new ConcurrentHashMap[Identifier, UnboundProcedure]() override protected def allNamespaces: Seq[Seq[String]] = { (tables.keySet.asScala.map(_.namespace.toSeq) ++ @@ -63,4 +66,18 @@ class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog { def clearFunctions(): Unit = { functions.clear() } + + override def loadProcedure(ident: Identifier): UnboundProcedure = { + val procedure = procedures.get(ident) + if (procedure == null) throw new RuntimeException("Procedure not found: " + ident) + procedure + } + + def createProcedure(ident: Identifier, procedure: UnboundProcedure): UnboundProcedure = { + procedures.put(ident, procedure) + } + + def clearProcedures(): Unit = { + procedures.clear() + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index 982de88e58847..56ed3bb243e19 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -167,7 +167,8 @@ class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamesp override def capabilities: java.util.Set[TableCatalogCapability] = { Set( TableCatalogCapability.SUPPORT_COLUMN_DEFAULT_VALUE, - TableCatalogCapability.SUPPORTS_CREATE_TABLE_WITH_GENERATED_COLUMNS + 
TableCatalogCapability.SUPPORTS_CREATE_TABLE_WITH_GENERATED_COLUMNS, + TableCatalogCapability.SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS ).asJava } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala index 1aa0b408366bf..a9d8a69128ae2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala @@ -121,7 +121,7 @@ class SupportsAtomicPartitionManagementSuite extends SparkFunSuite { exception = intercept[SparkUnsupportedOperationException] { partTable.purgePartitions(partIdents) }, - errorClass = "UNSUPPORTED_FEATURE.PURGE_PARTITION", + condition = "UNSUPPORTED_FEATURE.PURGE_PARTITION", parameters = Map.empty ) } @@ -170,7 +170,7 @@ class SupportsAtomicPartitionManagementSuite extends SparkFunSuite { partTable.truncatePartitions(Array(InternalRow("5"), InternalRow("6"))) } checkError(e, - errorClass = "PARTITIONS_NOT_FOUND", + condition = "PARTITIONS_NOT_FOUND", parameters = Map("partitionList" -> "PARTITION (`dt` = 6)", "tableName" -> "`test`.`ns`.`test_table`")) assert(partTable.rows === InternalRow(2, "zyx", "5") :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala index 06a23e7fda207..8581d4dec1fb8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala @@ -93,7 +93,7 @@ class SupportsPartitionManagementSuite extends SparkFunSuite { exception = intercept[SparkUnsupportedOperationException] { partTable.purgePartition(InternalRow.apply("3")) }, - errorClass = "UNSUPPORTED_FEATURE.PURGE_PARTITION", + condition = "UNSUPPORTED_FEATURE.PURGE_PARTITION", parameters = Map.empty ) } @@ -217,7 +217,7 @@ class SupportsPartitionManagementSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { partTable.partitionExists(InternalRow(0)) }, - errorClass = "_LEGACY_ERROR_TEMP_3208", + condition = "_LEGACY_ERROR_TEMP_3208", parameters = Map("numFields" -> "1", "schemaLen" -> "2")) } @@ -228,7 +228,7 @@ class SupportsPartitionManagementSuite extends SparkFunSuite { partTable.renamePartition(InternalRow(0, "abc"), InternalRow(1, "abc")) } checkError(e, - errorClass = "PARTITIONS_ALREADY_EXIST", + condition = "PARTITIONS_ALREADY_EXIST", parameters = Map("partitionList" -> "PARTITION (`part0` = 1, `part1` = abc)", "tableName" -> "`test`.`ns`.`test_table`")) @@ -237,7 +237,7 @@ class SupportsPartitionManagementSuite extends SparkFunSuite { partTable.renamePartition(newPart, InternalRow(3, "abc")) } checkError(e2, - errorClass = "PARTITIONS_NOT_FOUND", + condition = "PARTITIONS_NOT_FOUND", parameters = Map("partitionList" -> "PARTITION (`part0` = 2, `part1` = xyz)", "tableName" -> "`test`.`ns`.`test_table`")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 4343e464b2c80..3241f031a706b 100644 --- 
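With the InMemoryCatalog changes above, the test catalog gains a ProcedureCatalog surface backed by a ConcurrentHashMap, mirroring how it already tracks functions. A hedged usage sketch for a test; the mocked UnboundProcedure is hypothetical (Mockito assumed available on the test classpath), and createProcedure simply returns whatever ConcurrentHashMap.put returns, that is the previously registered procedure or null:

  import org.mockito.Mockito.mock

  val catalog = new InMemoryCatalog
  val ident = Identifier.of(Array("ns"), "my_proc")
  val proc = mock(classOf[UnboundProcedure])   // any stub works here

  catalog.createProcedure(ident, proc)
  assert(catalog.loadProcedure(ident) eq proc)   // resolves the registered procedure
  catalog.clearProcedures()                      // loadProcedure(ident) now throws RuntimeException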
a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -161,7 +161,7 @@ class DataTypeSuite extends SparkFunSuite { exception = intercept[SparkException] { left.merge(right) }, - errorClass = "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE", + condition = "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE", parameters = Map("left" -> "\"FLOAT\"", "right" -> "\"BIGINT\"" ) ) @@ -299,21 +299,21 @@ class DataTypeSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { DataType.fromJson(""""abcd"""") }, - errorClass = "INVALID_JSON_DATA_TYPE", + condition = "INVALID_JSON_DATA_TYPE", parameters = Map("invalidType" -> "abcd")) checkError( exception = intercept[SparkIllegalArgumentException] { DataType.fromJson("""{"abcd":"a"}""") }, - errorClass = "INVALID_JSON_DATA_TYPE", + condition = "INVALID_JSON_DATA_TYPE", parameters = Map("invalidType" -> """{"abcd":"a"}""")) checkError( exception = intercept[SparkIllegalArgumentException] { DataType.fromJson("""{"fields": [{"a":123}], "type": "struct"}""") }, - errorClass = "INVALID_JSON_DATA_TYPE", + condition = "INVALID_JSON_DATA_TYPE", parameters = Map("invalidType" -> """{"a":123}""")) // Malformed JSON string @@ -900,7 +900,7 @@ class DataTypeSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { DataType.fromJson(json) }, - errorClass = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS", + condition = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS", parameters = Map("jsonType" -> "integer") ) } @@ -934,7 +934,7 @@ class DataTypeSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { DataType.fromJson(json) }, - errorClass = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS", + condition = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS", parameters = Map("jsonType" -> "integer") ) } @@ -968,7 +968,7 @@ class DataTypeSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { DataType.fromJson(json) }, - errorClass = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS", + condition = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS", parameters = Map("jsonType" -> "map") ) } @@ -997,7 +997,7 @@ class DataTypeSuite extends SparkFunSuite { exception = intercept[SparkException] { DataType.fromJson(json) }, - errorClass = "COLLATION_INVALID_PROVIDER", + condition = "COLLATION_INVALID_PROVIDER", parameters = Map("provider" -> "badProvider", "supportedProviders" -> "spark, icu") ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala index 8c9196cc33ca5..f07ee8b35bbb2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala @@ -41,7 +41,7 @@ class StrictDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBa DataTypeUtils.canWrite("", widerPoint2, point2, true, analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`t`.`x`", @@ -60,7 +60,7 @@ class StrictDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBa DataTypeUtils.canWrite("", arrayOfLong, arrayOfInt, true, 
analysis.caseSensitiveResolution, "arr", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`arr`.`element`", @@ -79,7 +79,7 @@ class StrictDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBa DataTypeUtils.canWrite("", mapOfLong, mapOfInt, true, analysis.caseSensitiveResolution, "m", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`m`.`value`", @@ -98,7 +98,7 @@ class StrictDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBa DataTypeUtils.canWrite("", mapKeyLong, mapKeyInt, true, analysis.caseSensitiveResolution, "m", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`m`.`key`", @@ -116,7 +116,7 @@ class StrictDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBa analysis.caseSensitiveResolution, "nulls", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`nulls`", @@ -143,7 +143,7 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase DataTypeUtils.canWrite("", mapOfString, mapOfInt, true, analysis.caseSensitiveResolution, "m", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`m`.`value`", @@ -163,7 +163,7 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase DataTypeUtils.canWrite("", stringPoint2, point2, true, analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`t`.`x`", @@ -182,7 +182,7 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase DataTypeUtils.canWrite("", arrayOfString, arrayOfInt, true, analysis.caseSensitiveResolution, "arr", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`arr`.`element`", @@ -201,7 +201,7 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase DataTypeUtils.canWrite("", mapKeyString, mapKeyInt, true, analysis.caseSensitiveResolution, "arr", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`arr`.`key`", @@ -218,7 +218,7 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase analysis.caseSensitiveResolution, "longToTimestamp", storeAssignmentPolicy, errMsg => errs += errMsg) ), - 
errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`longToTimestamp`", @@ -231,7 +231,7 @@ class ANSIDataTypeWriteCompatibilitySuite extends DataTypeWriteCompatibilityBase analysis.caseSensitiveResolution, "timestampToLong", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`timestampToLong`", @@ -306,7 +306,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { DataTypeUtils.canWrite("", w, r, true, analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`t`", @@ -328,7 +328,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { DataTypeUtils.canWrite("", missingRequiredField, point2, true, analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.STRUCT_MISSING_FIELDS", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.STRUCT_MISSING_FIELDS", parameters = Map("tableName" -> "``", "colName" -> "`t`", "missingFields" -> "`y`") ) } @@ -341,7 +341,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { DataTypeUtils.canWrite("", missingRequiredField, point2, true, analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.UNEXPECTED_COLUMN_NAME", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.UNEXPECTED_COLUMN_NAME", parameters = Map( "expected" -> "`x`", "found" -> "`y`", @@ -369,7 +369,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { DataTypeUtils.canWrite("", missingMiddleField, expectedStruct, true, analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.UNEXPECTED_COLUMN_NAME", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.UNEXPECTED_COLUMN_NAME", parameters = Map( "expected" -> "`y`", "found" -> "`z`", @@ -406,7 +406,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { DataTypeUtils.canWrite("", requiredFieldIsOptional, point2, true, analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.NULLABLE_COLUMN", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.NULLABLE_COLUMN", parameters = Map("tableName" -> "``", "colName" -> "`t`.`x`") ) } @@ -418,7 +418,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { DataTypeUtils.canWrite("", point3, point2, true, analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS", parameters = Map("tableName" -> "``", "colName" -> "`t`", "extraFields" -> "`z`") ) } @@ -459,7 +459,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { DataTypeUtils.canWrite("", arrayOfOptional, arrayOfRequired, true, analysis.caseSensitiveResolution, "arr", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = 
"INCOMPATIBLE_DATA_FOR_TABLE.NULLABLE_ARRAY_ELEMENTS", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.NULLABLE_ARRAY_ELEMENTS", parameters = Map("tableName" -> "``", "colName" -> "`arr`") ) } @@ -489,7 +489,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { DataTypeUtils.canWrite("", mapOfOptional, mapOfRequired, true, analysis.caseSensitiveResolution, "m", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.NULLABLE_MAP_VALUES", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.NULLABLE_MAP_VALUES", parameters = Map("tableName" -> "``", "colName" -> "`m`") ) } @@ -560,7 +560,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { DataTypeUtils.canWrite("", sqlType, udtType, true, analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.NULLABLE_COLUMN", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.NULLABLE_COLUMN", parameters = Map( "tableName" -> "``", "colName" -> "`t`.`col2`" @@ -595,7 +595,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { DataTypeUtils.canWrite("", udtType, sqlType, true, analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`t`.`col2`", @@ -633,7 +633,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { DataTypeUtils.canWrite("", sqlType, udtType, true, analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`t`.`col2`", @@ -675,7 +675,7 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { DataTypeUtils.canWrite("", writeType, readType, true, analysis.caseSensitiveResolution, "t", storeAssignmentPolicy, errMsg => errs += errMsg) ), - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`t`.`a`.`element`", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index e6d915903f9bc..794112db5502a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -65,7 +65,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper checkError( exception = intercept[SparkArithmeticException](Decimal(170L, 2, 1)), - errorClass = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", + condition = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", parameters = Map( "value" -> "0", "precision" -> "2", @@ -73,7 +73,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper "config" -> "\"spark.sql.ansi.enabled\"")) checkError( exception = intercept[SparkArithmeticException](Decimal(170L, 2, 0)), - errorClass = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", + condition = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", parameters = Map( "value" -> "0", "precision" -> "2", @@ -81,7 +81,7 @@ class DecimalSuite extends SparkFunSuite with 
PrivateMethodTester with SQLHelper "config" -> "\"spark.sql.ansi.enabled\"")) checkError( exception = intercept[SparkArithmeticException](Decimal(BigDecimal("10.030"), 2, 1)), - errorClass = "NUMERIC_VALUE_OUT_OF_RANGE.WITHOUT_SUGGESTION", + condition = "NUMERIC_VALUE_OUT_OF_RANGE.WITHOUT_SUGGESTION", parameters = Map( "roundedValue" -> "10.0", "originalValue" -> "10.030", @@ -89,7 +89,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper "scale" -> "1")) checkError( exception = intercept[SparkArithmeticException](Decimal(BigDecimal("-9.95"), 2, 1)), - errorClass = "NUMERIC_VALUE_OUT_OF_RANGE.WITHOUT_SUGGESTION", + condition = "NUMERIC_VALUE_OUT_OF_RANGE.WITHOUT_SUGGESTION", parameters = Map( "roundedValue" -> "-10.0", "originalValue" -> "-9.95", @@ -97,7 +97,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper "scale" -> "1")) checkError( exception = intercept[SparkArithmeticException](Decimal(1e17.toLong, 17, 0)), - errorClass = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", + condition = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", parameters = Map( "value" -> "0", "precision" -> "17", @@ -120,7 +120,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper def checkNegativeScaleDecimal(d: => Decimal): Unit = { checkError( exception = intercept[SparkException] (d), - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> ("Negative scale is not allowed: -3. " + "Set the config \"spark.sql.legacy.allowNegativeScaleOfDecimal\" " + "to \"true\" to allow it.")) @@ -317,7 +317,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper exception = intercept[SparkException] { d.toPrecision(5, 50, BigDecimal.RoundingMode.HALF_DOWN) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "Not supported rounding mode: HALF_DOWN.") ) } @@ -350,7 +350,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper checkError( exception = intercept[SparkArithmeticException]( Decimal.fromStringANSI(UTF8String.fromString(string))), - errorClass = "NUMERIC_OUT_OF_SUPPORTED_RANGE", + condition = "NUMERIC_OUT_OF_SUPPORTED_RANGE", parameters = Map("value" -> string)) } @@ -370,12 +370,11 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper checkError( exception = intercept[SparkNumberFormatException]( Decimal.fromStringANSI(UTF8String.fromString("str"))), - errorClass = "CAST_INVALID_INPUT", + condition = "CAST_INVALID_INPUT", parameters = Map( "expression" -> "'str'", "sourceType" -> "\"STRING\"", - "targetType" -> "\"DECIMAL(10,0)\"", - "ansiConfig" -> "\"spark.sql.ansi.enabled\"")) + "targetType" -> "\"DECIMAL(10,0)\"")) } test("SPARK-35841: Casting string to decimal type doesn't work " + @@ -398,7 +397,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper checkError( exception = intercept[SparkArithmeticException]( Decimal.fromStringANSI(UTF8String.fromString(string))), - errorClass = "NUMERIC_OUT_OF_SUPPORTED_RANGE", + condition = "NUMERIC_OUT_OF_SUPPORTED_RANGE", parameters = Map("value" -> string)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index 562febe381130..6a67525dd02d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -45,21 +45,21 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { test("lookup a single missing field should output existing fields") { checkError( exception = intercept[SparkIllegalArgumentException](s("c")), - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`c`", "fields" -> "`a`, `b`")) } test("lookup a set of missing fields should output existing fields") { checkError( exception = intercept[SparkIllegalArgumentException](s(Set("a", "c"))), - errorClass = "NONEXISTENT_FIELD_NAME_IN_LIST", + condition = "NONEXISTENT_FIELD_NAME_IN_LIST", parameters = Map("nonExistFields" -> "`c`", "fieldNames" -> "`a`, `b`")) } test("lookup fieldIndex for missing field should output existing fields") { checkError( exception = intercept[SparkIllegalArgumentException](s.fieldIndex("c")), - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`c`", "fields" -> "`a`, `b`")) } @@ -341,7 +341,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } checkError( exception = e, - errorClass = "INVALID_FIELD_NAME", + condition = "INVALID_FIELD_NAME", parameters = Map( "fieldName" -> "`S1`.`S12`.`S123`", "path" -> "`s1`.`s12`")) @@ -352,7 +352,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } checkError( exception = e, - errorClass = "AMBIGUOUS_COLUMN_OR_FIELD", + condition = "AMBIGUOUS_COLUMN_OR_FIELD", parameters = Map("name" -> "`S2`.`x`", "n" -> "2")) caseSensitiveCheck(Seq("s2", "x"), Some(Seq("s2") -> StructField("x", IntegerType))) @@ -362,7 +362,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } checkError( exception = e, - errorClass = "INVALID_FIELD_NAME", + condition = "INVALID_FIELD_NAME", parameters = Map( "fieldName" -> "`m1`.`key`", "path" -> "`m1`")) @@ -373,7 +373,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } checkError( exception = e, - errorClass = "INVALID_FIELD_NAME", + condition = "INVALID_FIELD_NAME", parameters = Map( "fieldName" -> "`M1`.`key`.`name`", "path" -> "`m1`.`key`")) @@ -382,7 +382,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } checkError( exception = e, - errorClass = "INVALID_FIELD_NAME", + condition = "INVALID_FIELD_NAME", parameters = Map( "fieldName" -> "`M1`.`value`.`name`", "path" -> "`m1`.`value`")) @@ -399,7 +399,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } checkError( exception = e, - errorClass = "INVALID_FIELD_NAME", + condition = "INVALID_FIELD_NAME", parameters = Map( "fieldName" -> "`m2`.`key`.`A`.`name`", "path" -> "`m2`.`key`.`a`")) @@ -408,7 +408,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } checkError( exception = e, - errorClass = "INVALID_FIELD_NAME", + condition = "INVALID_FIELD_NAME", parameters = Map( "fieldName" -> "`M2`.`value`.`b`.`name`", "path" -> "`m2`.`value`.`b`")) @@ -418,7 +418,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } checkError( exception = e, - errorClass = "INVALID_FIELD_NAME", + condition = "INVALID_FIELD_NAME", parameters = Map( "fieldName" -> "`A1`.`element`", "path" -> "`a1`")) @@ -428,7 +428,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } checkError( exception = e, - errorClass = "INVALID_FIELD_NAME", + condition = "INVALID_FIELD_NAME", parameters = Map( "fieldName" -> "`A1`.`element`.`name`", "path" -> "`a1`.`element`")) @@ -442,7 +442,7 @@ class StructTypeSuite extends 
SparkFunSuite with SQLHelper { } checkError( exception = e, - errorClass = "INVALID_FIELD_NAME", + condition = "INVALID_FIELD_NAME", parameters = Map( "fieldName" -> "`a2`.`element`.`C`.`name`", "path" -> "`a2`.`element`.`c`")) @@ -456,7 +456,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } checkError( exception = e, - errorClass = "INVALID_FIELD_NAME", + condition = "INVALID_FIELD_NAME", parameters = Map( "fieldName" -> "`M3`.`value`.`value`.`MA`.`name`", "path" -> "`m3`.`value`.`value`.`ma`")) @@ -470,7 +470,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } checkError( exception = e, - errorClass = "INVALID_FIELD_NAME", + condition = "INVALID_FIELD_NAME", parameters = Map( "fieldName" -> "`A3`.`element`.`element`.`D`.`name`", "path" -> "`a3`.`element`.`element`.`d`") @@ -522,7 +522,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkException] { StructType.fromDDL("c1 DECIMAL(10, 5)").merge(StructType.fromDDL("c1 DECIMAL(12, 2)")) }, - errorClass = "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE", + condition = "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE", parameters = Map("left" -> "\"DECIMAL(10,5)\"", "right" -> "\"DECIMAL(12,2)\"") ) @@ -530,7 +530,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { exception = intercept[SparkException] { StructType.fromDDL("c1 DECIMAL(12, 5)").merge(StructType.fromDDL("c1 DECIMAL(12, 2)")) }, - errorClass = "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE", + condition = "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE", parameters = Map("left" -> "\"DECIMAL(12,5)\"", "right" -> "\"DECIMAL(12,2)\"") ) } @@ -564,7 +564,6 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { .putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY, "1 + 1") .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, "1 + 1") .build()))) - val error = "fails to parse as a valid literal value" assert(ResolveDefaultColumns.existenceDefaultValues(source2).length == 1) assert(ResolveDefaultColumns.existenceDefaultValues(source2)(0) == 2) @@ -576,9 +575,13 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { .putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY, "invalid") .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, "invalid") .build()))) - assert(intercept[AnalysisException] { - ResolveDefaultColumns.existenceDefaultValues(source3) - }.getMessage.contains(error)) + + checkError( + exception = intercept[AnalysisException]{ + ResolveDefaultColumns.existenceDefaultValues(source3) + }, + condition = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", + parameters = Map("statement" -> "", "colName" -> "`c1`", "defaultValue" -> "invalid")) // Negative test: StructType.defaultValues fails because the existence default value fails to // resolve. 
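For contrast with the new negative checks above, the positive path in the same test builds the default-value metadata directly and expects the existence default to fold to a literal. A condensed, hedged restatement of that setup (the field is assumed to be a single nullable c1 of IntegerType, which is what the resolved value 2 implies):

  val withDefault = StructType(Seq(StructField("c1", IntegerType, nullable = true,
    new MetadataBuilder()
      .putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY, "1 + 1")
      .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, "1 + 1")
      .build())))
  assert(ResolveDefaultColumns.existenceDefaultValues(withDefault).head == 2)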
@@ -592,9 +595,15 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, "(SELECT 'abc' FROM missingtable)") .build()))) - assert(intercept[AnalysisException] { - ResolveDefaultColumns.existenceDefaultValues(source4) - }.getMessage.contains(error)) + + checkError( + exception = intercept[AnalysisException]{ + ResolveDefaultColumns.existenceDefaultValues(source4) + }, + condition = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + parameters = Map("statement" -> "", + "colName" -> "`c1`", + "defaultValue" -> "(SELECT 'abc' FROM missingtable)")) } test("SPARK-46629: Test STRUCT DDL with NOT NULL round trip") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala index c0fa43ff9bde0..c705a6b791bd1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala @@ -54,13 +54,13 @@ class ArrowUtilsSuite extends SparkFunSuite { exception = intercept[SparkException] { roundtrip(TimestampType) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "Missing timezoneId where it is mandatory.")) checkError( exception = intercept[SparkUnsupportedOperationException] { ArrowUtils.fromArrowType(new ArrowType.Int(8, false)) }, - errorClass = "UNSUPPORTED_ARROWTYPE", + condition = "UNSUPPORTED_ARROWTYPE", parameters = Map("typeName" -> "Int(8, false)") ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala index 98c2a3d1e2726..932fb0a733371 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala @@ -67,7 +67,7 @@ class CaseInsensitiveStringMapSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { options.getBoolean("FOO", true) }, - errorClass = "_LEGACY_ERROR_TEMP_3206", + condition = "_LEGACY_ERROR_TEMP_3206", parameters = Map("value" -> "bar")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala index c5f19b438f27f..a277bb021c3f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala @@ -46,20 +46,20 @@ class SchemaUtilsSuite extends SparkFunSuite { exception = intercept[AnalysisException] { SchemaUtils.checkSchemaColumnNameDuplication(schema, caseSensitive) }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`a`")) checkError( exception = intercept[AnalysisException] { SchemaUtils.checkColumnNameDuplication(schema.map(_.name), resolver(caseSensitive)) }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`a`")) checkError( exception = intercept[AnalysisException] { SchemaUtils.checkColumnNameDuplication( schema.map(_.name), caseSensitiveAnalysis = caseSensitive) }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`a`")) } @@ -106,7 +106,7 @@ class 
SchemaUtilsSuite extends SparkFunSuite { exception = intercept[AnalysisException] { SchemaUtils.checkSchemaColumnNameDuplication(schema) }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`camelcase`")) } } diff --git a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto index 04fe21086097c..1003e5c21d639 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -76,6 +76,7 @@ message Relation { AsOfJoin as_of_join = 39; CommonInlineUserDefinedDataSource common_inline_user_defined_data_source = 40; WithRelations with_relations = 41; + Transpose transpose = 42; // NA functions NAFill fill_na = 90; @@ -889,6 +890,18 @@ message Unpivot { } } +// Transpose a DataFrame, switching rows to columns. +// Transforms the DataFrame such that the values in the specified index column +// become the new columns of the DataFrame. +message Transpose { + // (Required) The input relation. + Relation input = 1; + + // (Optional) A list of columns that will be treated as the indices. + // Only single column is supported now. + repeated Expression index_columns = 2; +} + message ToSchema { // (Required) The input relation. Relation input = 1; diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala index d77b4b820c090..2dba8fc3b7778 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala @@ -28,6 +28,9 @@ import org.apache.spark.sql.streaming.GroupState * mapPartitions etc. This class is shared between the client and the server so that when the * methods are used in client UDFs, the server will be able to find them when actually executing * the UDFs. + * + * DO NOT REMOVE/CHANGE THIS OBJECT OR ANY OF ITS METHODS, THEY ARE NEEDED FOR BACKWARDS + * COMPATIBILITY WITH OLDER CLIENTS! */ @SerialVersionUID(8464839273647598302L) private[sql] object UdfUtils extends Serializable { @@ -137,8 +140,6 @@ private[sql] object UdfUtils extends Serializable { // ---------------------------------------------------------------------------------------------- // Scala Functions wrappers for java UDFs. - // - // DO NOT REMOVE THESE, THEY ARE NEEDED FOR BACKWARDS COMPATIBILITY WITH OLDER CLIENTS! 
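The Transpose message added to relations.proto above carries only the input relation and at most one index column. A hedged sketch of the client-side calls that would produce it; the Dataset method name and its two overloads are inferred from the message shape and from the new transpose_index_column / transpose_no_index_column golden files later in this patch, not from this hunk:

  import org.apache.spark.sql.functions.col

  val df = spark.range(3).selectExpr("id", "id * 2 AS a", "id * 0.5 AS b")
  val byId      = df.transpose(col("id"))   // index_columns = [id]
  val defaulted = df.transpose()            // empty index_columns; the default index is resolved server-side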
// ---------------------------------------------------------------------------------------------- // (1 to 22).foreach { i => // val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath.explain index 6eb4805b4fcc4..61d81bad95c65 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath.explain @@ -1,2 +1,2 @@ -Project [from_protobuf(bytes#0, StorageLevel, Some([B)) AS from_protobuf(bytes)#0] +Project [from_protobuf(bytes#0, StorageLevel, Some([B)) AS from_protobuf(bytes, StorageLevel, X'0AFC010A0C636F6D6D6F6E2E70726F746F120D737061726B2E636F6E6E65637422B0010A0C53746F726167654C6576656C12190A087573655F6469736B18012001280852077573654469736B121D0A0A7573655F6D656D6F727918022001280852097573654D656D6F727912200A0C7573655F6F66665F68656170180320012808520A7573654F66664865617012220A0C646573657269616C697A6564180420012808520C646573657269616C697A656412200A0B7265706C69636174696F6E180520012805520B7265706C69636174696F6E42220A1E6F72672E6170616368652E737061726B2E636F6E6E6563742E70726F746F5001620670726F746F33', NULL)#0] +- LocalRelation , [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath_options.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath_options.explain index c4a47b1aef07b..066dba527a09a 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath_options.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_descFilePath_options.explain @@ -1,2 +1,2 @@ -Project [from_protobuf(bytes#0, StorageLevel, Some([B), (recursive.fields.max.depth,2)) AS from_protobuf(bytes)#0] +Project [from_protobuf(bytes#0, StorageLevel, Some([B), (recursive.fields.max.depth,2)) AS from_protobuf(bytes, StorageLevel, X'0AFC010A0C636F6D6D6F6E2E70726F746F120D737061726B2E636F6E6E65637422B0010A0C53746F726167654C6576656C12190A087573655F6469736B18012001280852077573654469736B121D0A0A7573655F6D656D6F727918022001280852097573654D656D6F727912200A0C7573655F6F66665F68656170180320012808520A7573654F66664865617012220A0C646573657269616C697A6564180420012808520C646573657269616C697A656412200A0B7265706C69636174696F6E180520012805520B7265706C69636174696F6E42220A1E6F72672E6170616368652E737061726B2E636F6E6E6563742E70726F746F5001620670726F746F33', map(recursive.fields.max.depth, 2))#0] +- LocalRelation , [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain index 1219f11d4696e..8d1d122d156ff 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain @@ -1,2 +1,2 @@ -Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles)) AS from_json(g)#0] +Project 
[from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain index 1219f11d4696e..8d1d122d156ff 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain @@ -1,2 +1,2 @@ -Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles)) AS from_json(g)#0] +Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain index 1219f11d4696e..8d1d122d156ff 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain @@ -1,2 +1,2 @@ -Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles)) AS from_json(g)#0] +Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_sentences.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_sentences.explain index 5c88a1f7b3abd..f4532e70675ae 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_sentences.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_sentences.explain @@ -1,2 +1,2 @@ -Project [sentences(g#0, , ) AS sentences(g, , )#0] +Project [static_invoke(ExpressionImplUtils.getSentences(g#0, , )) AS sentences(g, , )#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_sentences_with_language.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_sentences_with_language.explain new file mode 100644 index 0000000000000..37bcbf9a319b5 --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_sentences_with_language.explain @@ -0,0 +1,2 @@ +Project [static_invoke(ExpressionImplUtils.getSentences(g#0, en, )) AS sentences(g, en, )#0] ++- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_sentences_with_language_and_country.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_sentences_with_language_and_country.explain new file mode 100644 index 
0000000000000..8a8d54cfa0d10 --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_sentences_with_language_and_country.explain @@ -0,0 +1,2 @@ +Project [static_invoke(ExpressionImplUtils.getSentences(g#0, en, US)) AS sentences(g, en, US)#0] ++- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_sentences_with_locale.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_sentences_with_locale.explain deleted file mode 100644 index 7819f9b542340..0000000000000 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_sentences_with_locale.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [sentences(g#0, en, US) AS sentences(g, en, US)#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/melt_no_values.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/melt_no_values.explain index f61fc30a3a529..053937d84ec8f 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/melt_no_values.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/melt_no_values.explain @@ -1,2 +1,2 @@ -Expand [[id#0L, a#0, b, b#0]], [id#0L, a#0, #0, value#0] +Expand [[id#0L, a#0, b, b#0]], [id#0L, a#0, name#0, value#0] +- LocalRelation , [id#0L, a#0, b#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/melt_values.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/melt_values.explain index b5742d976dee9..5a953f792cd35 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/melt_values.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/melt_values.explain @@ -1,2 +1,2 @@ -Expand [[a#0, id, id#0L]], [a#0, #0, value#0L] +Expand [[a#0, id, id#0L]], [a#0, name#0, value#0L] +- LocalRelation , [id#0L, a#0, b#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName.explain index e7f70fa2c1a9e..3533406c0bf0a 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName.explain @@ -1,2 +1,2 @@ -Project [to_protobuf(bytes#0, org.apache.spark.connect.proto.StorageLevel, None) AS to_protobuf(bytes)#0] +Project [to_protobuf(bytes#0, org.apache.spark.connect.proto.StorageLevel, None) AS to_protobuf(bytes, org.apache.spark.connect.proto.StorageLevel, NULL, NULL)#0] +- LocalRelation , [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath.explain index 7c688cc446947..f6a33a20a5dcd 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath.explain @@ -1,2 +1,2 @@ -Project [to_protobuf(bytes#0, StorageLevel, Some([B)) AS to_protobuf(bytes)#0] +Project [to_protobuf(bytes#0, StorageLevel, Some([B)) AS 
to_protobuf(bytes, StorageLevel, X'0AFC010A0C636F6D6D6F6E2E70726F746F120D737061726B2E636F6E6E65637422B0010A0C53746F726167654C6576656C12190A087573655F6469736B18012001280852077573654469736B121D0A0A7573655F6D656D6F727918022001280852097573654D656D6F727912200A0C7573655F6F66665F68656170180320012808520A7573654F66664865617012220A0C646573657269616C697A6564180420012808520C646573657269616C697A656412200A0B7265706C69636174696F6E180520012805520B7265706C69636174696F6E42220A1E6F72672E6170616368652E737061726B2E636F6E6E6563742E70726F746F5001620670726F746F33', NULL)#0] +- LocalRelation , [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath_options.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath_options.explain index 9f05bb03c9c6d..393529a15670d 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath_options.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_descFilePath_options.explain @@ -1,2 +1,2 @@ -Project [to_protobuf(bytes#0, StorageLevel, Some([B), (recursive.fields.max.depth,2)) AS to_protobuf(bytes)#0] +Project [to_protobuf(bytes#0, StorageLevel, Some([B), (recursive.fields.max.depth,2)) AS to_protobuf(bytes, StorageLevel, X'0AFC010A0C636F6D6D6F6E2E70726F746F120D737061726B2E636F6E6E65637422B0010A0C53746F726167654C6576656C12190A087573655F6469736B18012001280852077573654469736B121D0A0A7573655F6D656D6F727918022001280852097573654D656D6F727912200A0C7573655F6F66665F68656170180320012808520A7573654F66664865617012220A0C646573657269616C697A6564180420012808520C646573657269616C697A656412200A0B7265706C69636174696F6E180520012805520B7265706C69636174696F6E42220A1E6F72672E6170616368652E737061726B2E636F6E6E6563742E70726F746F5001620670726F746F33', map(recursive.fields.max.depth, 2))#0] +- LocalRelation , [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_options.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_options.explain index a5d8851a7d1f3..e0c7e1625fe5c 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_options.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/to_protobuf_messageClassName_options.explain @@ -1,2 +1,2 @@ -Project [to_protobuf(bytes#0, org.apache.spark.connect.proto.StorageLevel, None, (recursive.fields.max.depth,2)) AS to_protobuf(bytes)#0] +Project [to_protobuf(bytes#0, org.apache.spark.connect.proto.StorageLevel, None, (recursive.fields.max.depth,2)) AS to_protobuf(bytes, org.apache.spark.connect.proto.StorageLevel, NULL, map(recursive.fields.max.depth, 2))#0] +- LocalRelation , [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/unpivot_no_values.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/unpivot_no_values.explain index 8d1749ee74c5a..2b2ba19d0c3db 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/unpivot_no_values.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/unpivot_no_values.explain @@ -1,2 +1,2 @@ -Expand [[id#0L, a, cast(a#0 as double)], [id#0L, b, b#0]], [id#0L, #0, value#0] +Expand [[id#0L, a, cast(a#0 as double)], [id#0L, b, b#0]], [id#0L, name#0, value#0] +- 
LocalRelation , [id#0L, a#0, b#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/unpivot_values.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/unpivot_values.explain index f61fc30a3a529..053937d84ec8f 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/unpivot_values.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/unpivot_values.explain @@ -1,2 +1,2 @@ -Expand [[id#0L, a#0, b, b#0]], [id#0L, a#0, #0, value#0] +Expand [[id#0L, a#0, b, b#0]], [id#0L, a#0, name#0, value#0] +- LocalRelation , [id#0L, a#0, b#0] diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_sentences_with_language.json b/sql/connect/common/src/test/resources/query-tests/queries/function_sentences_with_language.json new file mode 100644 index 0000000000000..869e074ccd604 --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_sentences_with_language.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "sentences", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "g" + } + }, { + "literal": { + "string": "en" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_sentences_with_language.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_sentences_with_language.proto.bin new file mode 100644 index 0000000000000..7514b380a1c82 Binary files /dev/null and b/sql/connect/common/src/test/resources/query-tests/queries/function_sentences_with_language.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_sentences_with_locale.json b/sql/connect/common/src/test/resources/query-tests/queries/function_sentences_with_language_and_country.json similarity index 100% rename from sql/connect/common/src/test/resources/query-tests/queries/function_sentences_with_locale.json rename to sql/connect/common/src/test/resources/query-tests/queries/function_sentences_with_language_and_country.json diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_sentences_with_locale.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_sentences_with_language_and_country.proto.bin similarity index 100% rename from sql/connect/common/src/test/resources/query-tests/queries/function_sentences_with_locale.proto.bin rename to sql/connect/common/src/test/resources/query-tests/queries/function_sentences_with_language_and_country.proto.bin diff --git a/sql/connect/common/src/test/resources/query-tests/queries/melt_no_values.json b/sql/connect/common/src/test/resources/query-tests/queries/melt_no_values.json index 12db0a5abe368..a17da06b925b9 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/melt_no_values.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/melt_no_values.json @@ -20,6 +20,7 @@ "unparsedIdentifier": "a" } }], + "variableColumnName": "name", "valueColumnName": "value" } } \ No newline at end of file diff --git 
a/sql/connect/common/src/test/resources/query-tests/queries/melt_no_values.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/melt_no_values.proto.bin index 23a6aa1289a99..eebb7ad6df8e2 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/melt_no_values.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/melt_no_values.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/melt_values.json b/sql/connect/common/src/test/resources/query-tests/queries/melt_values.json index e2a004f46e781..a8142ee3a8461 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/melt_values.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/melt_values.json @@ -23,6 +23,7 @@ } }] }, + "variableColumnName": "name", "valueColumnName": "value" } } \ No newline at end of file diff --git a/sql/connect/common/src/test/resources/query-tests/queries/melt_values.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/melt_values.proto.bin index e021e1110def5..35829fc62dae9 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/melt_values.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/melt_values.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/transpose_index_column.json b/sql/connect/common/src/test/resources/query-tests/queries/transpose_index_column.json new file mode 100644 index 0000000000000..19a2086c8d7df --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/queries/transpose_index_column.json @@ -0,0 +1,20 @@ +{ + "common": { + "planId": "1" + }, + "transpose": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double\u003e" + } + }, + "indexColumns": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "id" + } + }] + } +} \ No newline at end of file diff --git a/sql/connect/common/src/test/resources/query-tests/queries/transpose_index_column.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/transpose_index_column.proto.bin new file mode 100644 index 0000000000000..8590932d95cb4 Binary files /dev/null and b/sql/connect/common/src/test/resources/query-tests/queries/transpose_index_column.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/transpose_no_index_column.json b/sql/connect/common/src/test/resources/query-tests/queries/transpose_no_index_column.json new file mode 100644 index 0000000000000..82b0448c373e1 --- /dev/null +++ b/sql/connect/common/src/test/resources/query-tests/queries/transpose_no_index_column.json @@ -0,0 +1,15 @@ +{ + "common": { + "planId": "1" + }, + "transpose": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double\u003e" + } + } + } +} \ No newline at end of file diff --git a/sql/connect/common/src/test/resources/query-tests/queries/transpose_no_index_column.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/transpose_no_index_column.proto.bin new file mode 100644 index 0000000000000..c1ea985a64a4b Binary files /dev/null and b/sql/connect/common/src/test/resources/query-tests/queries/transpose_no_index_column.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.json b/sql/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.json index 9f550c0319147..96b76443b6790 
100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.json @@ -16,6 +16,7 @@ "unparsedIdentifier": "id" } }], + "variableColumnName": "name", "valueColumnName": "value" } } \ No newline at end of file diff --git a/sql/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.proto.bin index ac3bad8bd04ed..b700190a9f667 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/unpivot_no_values.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/unpivot_values.json b/sql/connect/common/src/test/resources/query-tests/queries/unpivot_values.json index 92bc19d195c6e..6c31afb04e741 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/unpivot_values.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/unpivot_values.json @@ -27,6 +27,7 @@ } }] }, + "variableColumnName": "name", "valueColumnName": "value" } } \ No newline at end of file diff --git a/sql/connect/common/src/test/resources/query-tests/queries/unpivot_values.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/unpivot_values.proto.bin index 7f717cb23517b..a1cd388fd8a46 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/unpivot_values.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/unpivot_values.proto.bin differ diff --git a/sql/connect/server/pom.xml b/sql/connect/server/pom.xml index 2e7c3f81aa1dc..3350c4261e9da 100644 --- a/sql/connect/server/pom.xml +++ b/sql/connect/server/pom.xml @@ -111,7 +111,7 @@ org.apache.spark spark-protobuf_${scala.binary.version} ${project.version} - provided + test org.apache.spark diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index 92709ff29a1ca..b64637f7d2472 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -63,7 +63,7 @@ object Connect { "conservatively use 70% of it because the size is not accurate but estimated.") .version("3.4.0") .bytesConf(ByteUnit.BYTE) - .createWithDefault(4 * 1024 * 1024) + .createWithDefault(ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE) val CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE = buildStaticConf("spark.connect.grpc.maxInboundMessageSize") diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala index a1881765a416c..72c77fd033d76 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala @@ -144,7 +144,9 @@ private[connect] class ConnectProgressExecutionListener extends SparkListener wi tracker.stages.get(taskEnd.stageId).foreach { stage => stage.update { i => i.completedTasks += 1 - i.inputBytesRead += taskEnd.taskMetrics.inputMetrics.bytesRead + i.inputBytesRead += 
Option(taskEnd.taskMetrics) + .map(_.inputMetrics.bytesRead) + .getOrElse(0L) } } // This should never become negative, simply reset to zero if it does. diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala index 3e360372d5600..051093fcad277 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala @@ -142,7 +142,9 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( * client, but rather enqueued to in the response observer. */ private def enqueueProgressMessage(force: Boolean = false): Unit = { - if (executeHolder.sessionHolder.session.conf.get(CONNECT_PROGRESS_REPORT_INTERVAL) > 0) { + val progressReportInterval = executeHolder.sessionHolder.session.sessionState.conf + .getConf(CONNECT_PROGRESS_REPORT_INTERVAL) + if (progressReportInterval > 0) { SparkConnectService.executionListener.foreach { listener => // It is possible, that the tracker is no longer available and in this // case we simply ignore it and do not send any progress message. This avoids @@ -240,8 +242,8 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( // monitor, and will notify upon state change. if (response.isEmpty) { // Wake up more frequently to send the progress updates. - val progressTimeout = - executeHolder.sessionHolder.session.conf.get(CONNECT_PROGRESS_REPORT_INTERVAL) + val progressTimeout = executeHolder.sessionHolder.session.sessionState.conf + .getConf(CONNECT_PROGRESS_REPORT_INTERVAL) // If the progress feature is disabled, wait for the deadline. 
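A stand-alone sketch of the null-safe read introduced above, with simplified stand-in types (InputMetricsStub and TaskMetricsStub are illustrative, not Spark's listener classes); the listener-suite change later in this patch exercises the same behaviour with a task end whose metrics are null:

final case class InputMetricsStub(bytesRead: Long)
final class TaskMetricsStub(val inputMetrics: InputMetricsStub)

object NullSafeMetricsSketch {
  // Option(null) collapses to None, so a task end without metrics contributes 0 bytes
  // instead of throwing a NullPointerException.
  def bytesReadOrZero(metrics: TaskMetricsStub): Long =
    Option(metrics)
      .map(_.inputMetrics.bytesRead)
      .getOrElse(0L)

  def demo(): Unit = {
    assert(bytesReadOrZero(null) == 0L)
    assert(bytesReadOrZero(new TaskMetricsStub(InputMetricsStub(500L))) == 500L)
  }
}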
val timeout = if (progressTimeout > 0) { progressTimeout diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 58e61badaf370..33c9edb1cd21a 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -40,14 +40,14 @@ import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult import org.apache.spark.connect.proto.Parse.ParseFormat import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase -import org.apache.spark.internal.{Logging, MDC} +import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID} import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest} -import org.apache.spark.sql.{Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession} +import org.apache.spark.sql.{Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, Row, SparkSession} import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro} import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker} -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedTranspose} +import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -78,9 +78,8 @@ import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} -import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils} +import org.apache.spark.sql.internal.{CatalogImpl, MergeIntoWriterImpl, TypedAggUtils} import org.apache.spark.sql.internal.ExpressionUtils.column -import org.apache.spark.sql.protobuf.{CatalystDataToProtobuf, ProtobufDataToCatalyst} import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger} import org.apache.spark.sql.types._ import 
org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -204,6 +203,7 @@ class SparkConnectPlanner( transformCachedLocalRelation(rel.getCachedLocalRelation) case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint) case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot) + case proto.Relation.RelTypeCase.TRANSPOSE => transformTranspose(rel.getTranspose) case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION => transformRepartitionByExpression(rel.getRepartitionByExpression) case proto.Relation.RelTypeCase.MAP_PARTITIONS => @@ -1126,6 +1126,13 @@ class SparkConnectPlanner( UnresolvedHint(rel.getName, params, transformRelation(rel.getInput)) } + private def transformTranspose(rel: proto.Transpose): LogicalPlan = { + val child = transformRelation(rel.getInput) + val indices = rel.getIndexColumnsList.asScala.map(transformExpression).toSeq + + UnresolvedTranspose(indices = indices, child = child) + } + private def transformUnpivot(rel: proto.Unpivot): LogicalPlan = { val ids = rel.getIdsList.asScala.toArray.map { expr => column(transformExpression(expr)) @@ -1863,67 +1870,15 @@ class SparkConnectPlanner( } Some(CatalystDataToAvro(children.head, jsonFormatSchema)) - // Protobuf-specific functions - case "from_protobuf" if Seq(2, 3, 4).contains(fun.getArgumentsCount) => - val children = fun.getArgumentsList.asScala.map(transformExpression) - val (msgName, desc, options) = extractProtobufArgs(children.toSeq) - Some(ProtobufDataToCatalyst(children(0), msgName, desc, options)) - - case "to_protobuf" if Seq(2, 3, 4).contains(fun.getArgumentsCount) => - val children = fun.getArgumentsList.asScala.map(transformExpression) - val (msgName, desc, options) = extractProtobufArgs(children.toSeq) - Some(CatalystDataToProtobuf(children(0), msgName, desc, options)) - case _ => None } } - private def extractProtobufArgs(children: Seq[Expression]) = { - val msgName = extractString(children(1), "MessageClassName") - var desc = Option.empty[Array[Byte]] - var options = Map.empty[String, String] - if (children.length == 3) { - children(2) match { - case b: Literal => desc = Some(extractBinary(b, "binaryFileDescriptorSet")) - case o => options = extractMapData(o, "options") - } - } else if (children.length == 4) { - desc = Some(extractBinary(children(2), "binaryFileDescriptorSet")) - options = extractMapData(children(3), "options") - } - (msgName, desc, options) - } - - private def extractBoolean(expr: Expression, field: String): Boolean = expr match { - case Literal(bool: Boolean, BooleanType) => bool - case other => throw InvalidPlanInput(s"$field should be a literal boolean, but got $other") - } - - private def extractDouble(expr: Expression, field: String): Double = expr match { - case Literal(double: Double, DoubleType) => double - case other => throw InvalidPlanInput(s"$field should be a literal double, but got $other") - } - - private def extractInteger(expr: Expression, field: String): Int = expr match { - case Literal(int: Int, IntegerType) => int - case other => throw InvalidPlanInput(s"$field should be a literal integer, but got $other") - } - - private def extractLong(expr: Expression, field: String): Long = expr match { - case Literal(long: Long, LongType) => long - case other => throw InvalidPlanInput(s"$field should be a literal long, but got $other") - } - private def extractString(expr: Expression, field: String): String = expr match { case Literal(s, StringType) if s != null => s.toString case other => throw InvalidPlanInput(s"$field should be a literal string, but got 
$other") } - private def extractBinary(expr: Expression, field: String): Array[Byte] = expr match { - case Literal(b: Array[Byte], BinaryType) if b != null => b - case other => throw InvalidPlanInput(s"$field should be a literal binary, but got $other") - } - @scala.annotation.tailrec private def extractMapData(expr: Expression, field: String): Map[String, String] = expr match { case map: CreateMap => ExprUtils.convertToMapData(map) @@ -1932,23 +1887,6 @@ class SparkConnectPlanner( case other => throw InvalidPlanInput(s"$field should be created by map, but got $other") } - // Extract the schema from a literal string representing a JSON-formatted schema - private def extractDataTypeFromJSON(exp: proto.Expression): Option[DataType] = { - exp.getExprTypeCase match { - case proto.Expression.ExprTypeCase.LITERAL => - exp.getLiteral.getLiteralTypeCase match { - case proto.Expression.Literal.LiteralTypeCase.STRING => - try { - Some(DataType.fromJson(exp.getLiteral.getString)) - } catch { - case _: Exception => None - } - case _ => None - } - case _ => None - } - } - private def transformAlias(alias: proto.Expression.Alias): NamedExpression = { if (alias.getNameCount == 1) { val metadata = if (alias.hasMetadata() && alias.getMetadata.nonEmpty) { @@ -2380,16 +2318,17 @@ class SparkConnectPlanner( if (fun.getArgumentsCount != 1) { throw InvalidPlanInput("reduce requires single child expression") } - val udf = fun.getArgumentsList.asScala.map(transformExpression) match { - case collection.Seq(f: ScalaUDF) => - f + val udf = fun.getArgumentsList.asScala match { + case collection.Seq(e) + if e.hasCommonInlineUserDefinedFunction && + e.getCommonInlineUserDefinedFunction.hasScalarScalaUdf => + unpackUdf(e.getCommonInlineUserDefinedFunction) case other => throw InvalidPlanInput(s"reduce should carry a scalar scala udf, but got $other") } - assert(udf.outputEncoder.isDefined) - val tEncoder = udf.outputEncoder.get // (T, T) => T - val reduce = ReduceAggregator(udf.function)(tEncoder).toColumn.expr - TypedAggUtils.withInputType(reduce, tEncoder, dataAttributes) + val encoder = udf.outputEncoder + val reduce = ReduceAggregator(udf.function)(encoder).toColumn.expr + TypedAggUtils.withInputType(reduce, encoderFor(encoder), dataAttributes) } private def transformExpressionWithTypedReduceExpression( @@ -3114,10 +3053,13 @@ class SparkConnectPlanner( sessionHolder.streamingServersideListenerHolder.streamingQueryStartedEventCache.remove( query.runId.toString)) queryStartedEvent.foreach { - logDebug( - s"[SessionId: $sessionId][UserId: $userId][operationId: " + - s"${executeHolder.operationId}][query id: ${query.id}][query runId: ${query.runId}] " + - s"Adding QueryStartedEvent to response") + logInfo( + log"[SessionId: ${MDC(LogKeys.SESSION_ID, sessionId)}]" + + log"[UserId: ${MDC(LogKeys.USER_ID, userId)}] " + + log"[operationId: ${MDC(LogKeys.OPERATION_ID, executeHolder.operationId)}] " + + log"[query id: ${MDC(LogKeys.QUERY_ID, query.id)}]" + + log"[query runId: ${MDC(LogKeys.QUERY_RUN_ID, query.runId)}] " + + log"Adding QueryStartedEvent to response") e => resultBuilder.setQueryStartedEventJson(e.json) } @@ -3496,16 +3438,14 @@ class SparkConnectPlanner( val notMatchedBySourceActions = transformActions(cmd.getNotMatchedBySourceActionsList) val sourceDs = Dataset.ofRows(session, transformRelation(cmd.getSourceTablePlan)) - var mergeInto = sourceDs + val mergeInto = sourceDs .mergeInto(cmd.getTargetTableName, column(transformExpression(cmd.getMergeCondition))) - .withNewMatchedActions(matchedActions: _*) - 
.withNewNotMatchedActions(notMatchedActions: _*) - .withNewNotMatchedBySourceActions(notMatchedBySourceActions: _*) - - mergeInto = if (cmd.getWithSchemaEvolution) { + .asInstanceOf[MergeIntoWriterImpl[Row]] + mergeInto.matchedActions ++= matchedActions + mergeInto.notMatchedActions ++= notMatchedActions + mergeInto.notMatchedBySourceActions ++= notMatchedBySourceActions + if (cmd.getWithSchemaEvolution) { mergeInto.withSchemaEvolution() - } else { - mergeInto } mergeInto.merge() executeHolder.eventsManager.postFinished() diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala index ec7ebbe92d72e..dc349c3e33251 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql.connect.service -import java.util.UUID - import scala.collection.mutable import scala.jdk.CollectionConverters._ -import org.apache.spark.{SparkEnv, SparkSQLException} +import org.apache.spark.SparkEnv import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.sql.Observation @@ -35,30 +33,19 @@ import org.apache.spark.util.SystemClock * Object used to hold the Spark Connect execution state. */ private[connect] class ExecuteHolder( + val executeKey: ExecuteKey, val request: proto.ExecutePlanRequest, val sessionHolder: SessionHolder) extends Logging { val session = sessionHolder.session - val operationId = if (request.hasOperationId) { - try { - UUID.fromString(request.getOperationId).toString - } catch { - case _: IllegalArgumentException => - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.FORMAT", - messageParameters = Map("handle" -> request.getOperationId)) - } - } else { - UUID.randomUUID().toString - } - /** * Tag that is set for this execution on SparkContext, via SparkContext.addJobTag. Used * (internally) for cancellation of the Spark Jobs ran by this execution. */ - val jobTag = ExecuteJobTag(sessionHolder.userId, sessionHolder.sessionId, operationId) + val jobTag = + ExecuteJobTag(sessionHolder.userId, sessionHolder.sessionId, executeKey.operationId) /** * Tags set by Spark Connect client users via SparkSession.addTag. Used to identify and group @@ -278,7 +265,7 @@ private[connect] class ExecuteHolder( request = request, userId = sessionHolder.userId, sessionId = sessionHolder.sessionId, - operationId = operationId, + operationId = executeKey.operationId, jobTag = jobTag, sparkSessionTags = sparkSessionTags, reattachable = reattachable, @@ -289,7 +276,10 @@ private[connect] class ExecuteHolder( } /** Get key used by SparkConnectExecutionManager global tracker. */ - def key: ExecuteKey = ExecuteKey(sessionHolder.userId, sessionHolder.sessionId, operationId) + def key: ExecuteKey = executeKey + + /** Get the operation ID. */ + def operationId: String = key.operationId } /** Used to identify ExecuteHolder jobTag among SparkContext.SPARK_JOB_TAGS. 
*/ diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 0cb820b39e875..e56d66da3050d 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -444,8 +444,8 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio */ private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)( transform: proto.Relation => LogicalPlan): LogicalPlan = { - val planCacheEnabled = - Option(session).forall(_.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true)) + val planCacheEnabled = Option(session) + .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true)) // We only cache plans that have a plan ID. val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala index 1ab5f26f90b13..73a20e448be87 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala @@ -31,6 +31,15 @@ class SparkConnectExecutePlanHandler(responseObserver: StreamObserver[proto.Exec try { executeHolder.eventsManager.postStarted() executeHolder.start() + } catch { + // Errors raised before the execution holder has finished spawning a thread are considered + // plan execution failure, and the client should not try reattaching it afterwards. 
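A reviewer-style sketch of the cleanup-on-failure pattern applied by the catch clause just below: the holder is unregistered before the error propagates, so a later reattach sees INVALID_HANDLE.OPERATION_NOT_FOUND rather than a dead execution, which is what the new ReattachableExecuteSuite test later in this patch expects. Registry and startOrCleanUp are hypothetical names, not Connect classes.

import java.util.concurrent.ConcurrentHashMap

object CleanupOnFailureSketch {
  final class Registry[K, V] {
    private val entries = new ConcurrentHashMap[K, V]()
    def put(key: K, value: V): Unit = entries.put(key, value)
    def remove(key: K): Unit = entries.remove(key)
    def get(key: K): Option[V] = Option(entries.get(key))
  }

  // Register first, then start; if starting fails before any worker thread exists,
  // unregister so retries cannot find (and reattach to) an execution that never ran.
  def startOrCleanUp[K, V](registry: Registry[K, V], key: K, value: V)(start: => Unit): Unit = {
    registry.put(key, value)
    try start
    catch {
      case t: Throwable =>
        registry.remove(key)
        throw t
    }
  }
}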
+ case t: Throwable => + SparkConnectService.executionManager.removeExecuteHolder(executeHolder.key) + throw t + } + + try { val responseSender = new ExecuteGrpcResponseSender[proto.ExecutePlanResponse](executeHolder, responseObserver) executeHolder.runGrpcResponseSender(responseSender) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala index 6681a5f509c6e..d66964b8d34bd 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.connect.service -import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit} -import javax.annotation.concurrent.GuardedBy +import java.util.UUID +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit} +import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import scala.collection.mutable import scala.concurrent.duration.FiniteDuration @@ -36,6 +37,24 @@ import org.apache.spark.util.ThreadUtils // Unique key identifying execution by combination of user, session and operation id case class ExecuteKey(userId: String, sessionId: String, operationId: String) +object ExecuteKey { + def apply(request: proto.ExecutePlanRequest, sessionHolder: SessionHolder): ExecuteKey = { + val operationId = if (request.hasOperationId) { + try { + UUID.fromString(request.getOperationId).toString + } catch { + case _: IllegalArgumentException => + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.FORMAT", + messageParameters = Map("handle" -> request.getOperationId)) + } + } else { + UUID.randomUUID().toString + } + ExecuteKey(sessionHolder.userId, sessionHolder.sessionId, operationId) + } +} + /** * Global tracker of all ExecuteHolder executions. * @@ -44,11 +63,9 @@ case class ExecuteKey(userId: String, sessionId: String, operationId: String) */ private[connect] class SparkConnectExecutionManager() extends Logging { - /** Hash table containing all current executions. Guarded by executionsLock. */ - @GuardedBy("executionsLock") - private val executions: mutable.HashMap[ExecuteKey, ExecuteHolder] = - new mutable.HashMap[ExecuteKey, ExecuteHolder]() - private val executionsLock = new Object + /** Concurrent hash table containing all the current executions. */ + private val executions: ConcurrentMap[ExecuteKey, ExecuteHolder] = + new ConcurrentHashMap[ExecuteKey, ExecuteHolder]() /** Graveyard of tombstones of executions that were abandoned and removed. */ private val abandonedTombstones = CacheBuilder @@ -56,12 +73,12 @@ private[connect] class SparkConnectExecutionManager() extends Logging { .maximumSize(SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_ABANDONED_TOMBSTONES_SIZE)) .build[ExecuteKey, ExecuteInfo]() - /** None if there are no executions. Otherwise, the time when the last execution was removed. */ - @GuardedBy("executionsLock") - private var lastExecutionTimeMs: Option[Long] = Some(System.currentTimeMillis()) + /** The time when the last execution was removed. 
*/ + private var lastExecutionTimeMs: AtomicLong = new AtomicLong(System.currentTimeMillis()) /** Executor for the periodic maintenance */ - private var scheduledExecutor: Option[ScheduledExecutorService] = None + private var scheduledExecutor: AtomicReference[ScheduledExecutorService] = + new AtomicReference[ScheduledExecutorService]() /** * Create a new ExecuteHolder and register it with this global manager and with its session. @@ -76,27 +93,30 @@ private[connect] class SparkConnectExecutionManager() extends Logging { request.getUserContext.getUserId, request.getSessionId, previousSessionId) - val executeHolder = new ExecuteHolder(request, sessionHolder) - executionsLock.synchronized { - // Check if the operation already exists, both in active executions, and in the graveyard - // of tombstones of executions that have been abandoned. - // The latter is to prevent double execution when a client retries execution, thinking it - // never reached the server, but in fact it did, and already got removed as abandoned. - if (executions.get(executeHolder.key).isDefined) { - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS", - messageParameters = Map("handle" -> executeHolder.operationId)) - } - if (getAbandonedTombstone(executeHolder.key).isDefined) { - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.OPERATION_ABANDONED", - messageParameters = Map("handle" -> executeHolder.operationId)) - } - sessionHolder.addExecuteHolder(executeHolder) - executions.put(executeHolder.key, executeHolder) - lastExecutionTimeMs = None - logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} is created.") - } + val executeKey = ExecuteKey(request, sessionHolder) + val executeHolder = executions.compute( + executeKey, + (executeKey, oldExecuteHolder) => { + // Check if the operation already exists, either in the active execution map, or in the + // graveyard of tombstones of executions that have been abandoned. The latter is to prevent + // double executions when the client retries, thinking it never reached the server, but in + // fact it did, and already got removed as abandoned. + if (oldExecuteHolder != null) { + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS", + messageParameters = Map("handle" -> executeKey.operationId)) + } + if (getAbandonedTombstone(executeKey).isDefined) { + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.OPERATION_ABANDONED", + messageParameters = Map("handle" -> executeKey.operationId)) + } + new ExecuteHolder(executeKey, request, sessionHolder) + }) + + sessionHolder.addExecuteHolder(executeHolder) + + logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} is created.") schedulePeriodicChecks() // Starts the maintenance thread if it hasn't started. @@ -108,43 +128,46 @@ private[connect] class SparkConnectExecutionManager() extends Logging { * execution if still running, free all resources. */ private[connect] def removeExecuteHolder(key: ExecuteKey, abandoned: Boolean = false): Unit = { - var executeHolder: Option[ExecuteHolder] = None - executionsLock.synchronized { - executeHolder = executions.remove(key) - executeHolder.foreach { e => - // Put into abandonedTombstones under lock, so that if it's accessed it will end up - // with INVALID_HANDLE.OPERATION_ABANDONED error. 
- if (abandoned) { - abandonedTombstones.put(key, e.getExecuteInfo) - } - e.sessionHolder.removeExecuteHolder(e.operationId) - } - if (executions.isEmpty) { - lastExecutionTimeMs = Some(System.currentTimeMillis()) - } - logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.") + val executeHolder = executions.get(key) + + if (executeHolder == null) { + return } - // close the execution outside the lock - executeHolder.foreach { e => - e.close() - if (abandoned) { - // Update in abandonedTombstones: above it wasn't yet updated with closedTime etc. - abandonedTombstones.put(key, e.getExecuteInfo) - } + + // Put into abandonedTombstones before removing it from executions, so that the client ends up + // getting an INVALID_HANDLE.OPERATION_ABANDONED error on a retry. + if (abandoned) { + abandonedTombstones.put(key, executeHolder.getExecuteInfo) + } + + // Remove the execution from the map *after* putting it in abandonedTombstones. + executions.remove(key) + executeHolder.sessionHolder.removeExecuteHolder(executeHolder.operationId) + + updateLastExecutionTime() + + logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.") + + executeHolder.close() + if (abandoned) { + // Update in abandonedTombstones: above it wasn't yet updated with closedTime etc. + abandonedTombstones.put(key, executeHolder.getExecuteInfo) } } private[connect] def getExecuteHolder(key: ExecuteKey): Option[ExecuteHolder] = { - executionsLock.synchronized { - executions.get(key) - } + Option(executions.get(key)) } private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = { - val sessionExecutionHolders = executionsLock.synchronized { - executions.filter(_._2.sessionHolder.key == key) - } - sessionExecutionHolders.foreach { case (_, executeHolder) => + var sessionExecutionHolders = mutable.ArrayBuffer[ExecuteHolder]() + executions.forEach((_, executeHolder) => { + if (executeHolder.sessionHolder.key == key) { + sessionExecutionHolders += executeHolder + } + }) + + sessionExecutionHolders.foreach { executeHolder => val info = executeHolder.getExecuteInfo logInfo( log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in removeSessionExecutions.") @@ -161,11 +184,11 @@ private[connect] class SparkConnectExecutionManager() extends Logging { * If there are no executions, return Left with System.currentTimeMillis of last active * execution. Otherwise return Right with list of ExecuteInfo of all executions. */ - def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = executionsLock.synchronized { + def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = { if (executions.isEmpty) { - Left(lastExecutionTimeMs.get) + Left(lastExecutionTimeMs.getAcquire()) } else { - Right(executions.values.map(_.getExecuteInfo).toBuffer.toSeq) + Right(executions.values().asScala.map(_.getExecuteInfo).toBuffer.toSeq) } } @@ -177,17 +200,24 @@ private[connect] class SparkConnectExecutionManager() extends Logging { abandonedTombstones.asMap.asScala.values.toSeq } - private[connect] def shutdown(): Unit = executionsLock.synchronized { - scheduledExecutor.foreach { executor => + private[connect] def shutdown(): Unit = { + val executor = scheduledExecutor.getAndSet(null) + if (executor != null) { ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES)) } - scheduledExecutor = None + // note: this does not cleanly shut down the executions, but the server is shutting down. 
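The getAndSet(null) above makes shutdown a one-shot operation: only the caller that actually claims the executor shuts it down, and a second shutdown call is a no-op. The same idea in isolation, with the JDK scheduler in place of Spark's ThreadUtils helper:

import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

object OneShotShutdownSketch {
  private val scheduler = new AtomicReference[ScheduledExecutorService](
    Executors.newSingleThreadScheduledExecutor())

  def shutdown(): Unit = {
    // getAndSet(null) atomically claims the executor; later callers observe null and return.
    val executor = scheduler.getAndSet(null)
    if (executor != null) {
      executor.shutdown()
      executor.awaitTermination(1, TimeUnit.MINUTES)
    }
  }
}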
executions.clear() abandonedTombstones.invalidateAll() - if (lastExecutionTimeMs.isEmpty) { - lastExecutionTimeMs = Some(System.currentTimeMillis()) - } + + updateLastExecutionTime() + } + + /** + * Updates the last execution time after the last execution has been removed. + */ + private def updateLastExecutionTime(): Unit = { + lastExecutionTimeMs.getAndUpdate(prev => prev.max(System.currentTimeMillis())) } /** @@ -195,16 +225,16 @@ private[connect] class SparkConnectExecutionManager() extends Logging { * for executions that have not been closed, but are left with no RPC attached to them, and * removes them after a timeout. */ - private def schedulePeriodicChecks(): Unit = executionsLock.synchronized { - scheduledExecutor match { - case Some(_) => // Already running. - case None => + private def schedulePeriodicChecks(): Unit = { + var executor = scheduledExecutor.getAcquire() + if (executor == null) { + executor = Executors.newSingleThreadScheduledExecutor() + if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) { val interval = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_MAINTENANCE_INTERVAL) logInfo( log"Starting thread for cleanup of abandoned executions every " + log"${MDC(LogKeys.INTERVAL, interval)} ms") - scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor()) - scheduledExecutor.get.scheduleAtFixedRate( + executor.scheduleAtFixedRate( () => { try { val timeout = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT) @@ -216,6 +246,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { interval, interval, TimeUnit.MILLISECONDS) + } } } @@ -225,19 +256,18 @@ private[connect] class SparkConnectExecutionManager() extends Logging { // Find any detached executions that expired and should be removed. val toRemove = new mutable.ArrayBuffer[ExecuteHolder]() - executionsLock.synchronized { - val nowMs = System.currentTimeMillis() - - executions.values.foreach { executeHolder => - executeHolder.lastAttachedRpcTimeMs match { - case Some(detached) => - if (detached + timeout <= nowMs) { - toRemove += executeHolder - } - case _ => // execution is active - } + val nowMs = System.currentTimeMillis() + + executions.forEach((_, executeHolder) => { + executeHolder.lastAttachedRpcTimeMs match { + case Some(detached) => + if (detached + timeout <= nowMs) { + toRemove += executeHolder + } + case _ => // execution is active } - } + }) + // .. and remove them. toRemove.foreach { executeHolder => val info = executeHolder.getExecuteInfo @@ -250,16 +280,16 @@ private[connect] class SparkConnectExecutionManager() extends Logging { } // For testing. - private[connect] def setAllRPCsDeadline(deadlineMs: Long) = executionsLock.synchronized { - executions.values.foreach(_.setGrpcResponseSendersDeadline(deadlineMs)) + private[connect] def setAllRPCsDeadline(deadlineMs: Long) = { + executions.values().asScala.foreach(_.setGrpcResponseSendersDeadline(deadlineMs)) } // For testing. 
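updateLastExecutionTime above trades the lock-guarded Option[Long] for an AtomicLong that only moves forward: getAndUpdate with max keeps the timestamp monotonic even when two removals race. The same update in isolation:

import java.util.concurrent.atomic.AtomicLong

object MonotonicTimestampSketch {
  private val lastExecutionTimeMs = new AtomicLong(System.currentTimeMillis())

  // Concurrent callers may apply their updates in any order; taking the max guarantees
  // the stored value never goes backwards.
  def touch(): Unit =
    lastExecutionTimeMs.getAndUpdate(prev => prev.max(System.currentTimeMillis()))

  def lastMs: Long = lastExecutionTimeMs.get()
}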
- private[connect] def interruptAllRPCs() = executionsLock.synchronized { - executions.values.foreach(_.interruptGrpcResponseSenders()) + private[connect] def interruptAllRPCs() = { + executions.values().asScala.foreach(_.interruptGrpcResponseSenders()) } - private[connect] def listExecuteHolders: Seq[ExecuteHolder] = executionsLock.synchronized { - executions.values.toSeq + private[connect] def listExecuteHolders: Seq[ExecuteHolder] = { + executions.values().asScala.toSeq } } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala index 5b2205757648f..7a0c067ab430b 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala @@ -160,9 +160,11 @@ private[sql] class SparkConnectListenerBusListener( } override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = { - logDebug( - s"[SessionId: ${sessionHolder.sessionId}][UserId: ${sessionHolder.userId}] " + - s"Sending QueryTerminatedEvent to client, id: ${event.id} runId: ${event.runId}.") + logInfo( + log"[SessionId: ${MDC(LogKeys.SESSION_ID, sessionHolder.sessionId)}]" + + log"[UserId: ${MDC(LogKeys.USER_ID, sessionHolder.userId)}] " + + log"Sending QueryTerminatedEvent to client, id: ${MDC(LogKeys.QUERY_ID, event.id)} " + + log"runId: ${MDC(LogKeys.QUERY_RUN_ID, event.runId)}.") send(event.json, StreamingQueryEventType.QUERY_TERMINATED_EVENT) } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala index edaaa640bf12e..4ca3a80bfb985 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.connect.service import java.util.UUID -import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit} -import javax.annotation.concurrent.GuardedBy +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit} +import java.util.concurrent.atomic.AtomicReference import scala.collection.mutable import scala.concurrent.duration.FiniteDuration @@ -40,10 +40,8 @@ import org.apache.spark.util.ThreadUtils */ class SparkConnectSessionManager extends Logging { - private val sessionsLock = new Object - - @GuardedBy("sessionsLock") - private val sessionStore = mutable.HashMap[SessionKey, SessionHolder]() + private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] = + new ConcurrentHashMap[SessionKey, SessionHolder]() private val closedSessionsCache = CacheBuilder @@ -52,7 +50,8 @@ class SparkConnectSessionManager extends Logging { .build[SessionKey, SessionHolderInfo]() /** Executor for the periodic maintenance */ - private var scheduledExecutor: Option[ScheduledExecutorService] = None + private var scheduledExecutor: AtomicReference[ScheduledExecutorService] = + new AtomicReference[ScheduledExecutorService]() private def validateSessionId( key: SessionKey, @@ -74,8 +73,6 @@ class SparkConnectSessionManager extends Logging { val holder = 
getSession( key, Some(() => { - // Executed under sessionsState lock in getSession, to guard against concurrent removal - // and insertion into closedSessionsCache. validateSessionCreate(key) val holder = SessionHolder(key.userId, key.sessionId, newIsolatedSession()) holder.initializeSession() @@ -121,43 +118,39 @@ class SparkConnectSessionManager extends Logging { private def getSession(key: SessionKey, default: Option[() => SessionHolder]): SessionHolder = { schedulePeriodicChecks() // Starts the maintenance thread if it hasn't started yet. - sessionsLock.synchronized { - // try to get existing session from store - val sessionOpt = sessionStore.get(key) - // create using default if missing - val session = sessionOpt match { - case Some(s) => s - case None => - default match { - case Some(callable) => - val session = callable() - sessionStore.put(key, session) - session - case None => - null - } - } - // record access time before returning - session match { - case null => - null - case s: SessionHolder => - s.updateAccessTime() - s - } + // Get the existing session from the store or create a new one. + val session = default match { + case Some(callable) => + sessionStore.computeIfAbsent(key, _ => callable()) + case None => + sessionStore.get(key) + } + + // Record the access time before returning the session holder. + if (session != null) { + session.updateAccessTime() } + + session } // Removes session from sessionStore and returns it. private def removeSessionHolder(key: SessionKey): Option[SessionHolder] = { var sessionHolder: Option[SessionHolder] = None - sessionsLock.synchronized { - sessionHolder = sessionStore.remove(key) - sessionHolder.foreach { s => - // Put into closedSessionsCache, so that it cannot get accidentally recreated - // by getOrCreateIsolatedSession. - closedSessionsCache.put(s.key, s.getSessionHolderInfo) - } + + // The session holder should remain in the session store until it is added to the closed session + // cache, because of a subtle data race: a new session with the same key can be created if the + // closed session cache does not contain the key right after the key has been removed from the + // session store. + sessionHolder = Option(sessionStore.get(key)) + + sessionHolder.foreach { s => + // Put into closedSessionsCache to prevent the same session from being recreated by + // getOrCreateIsolatedSession. + closedSessionsCache.put(s.key, s.getSessionHolderInfo) + + // Then, remove the session holder from the session store. + sessionStore.remove(key) } sessionHolder } @@ -171,26 +164,26 @@ class SparkConnectSessionManager extends Logging { def closeSession(key: SessionKey): Unit = { val sessionHolder = removeSessionHolder(key) - // Rest of the cleanup outside sessionLock - the session cannot be accessed anymore by - // getOrCreateIsolatedSession. + // Rest of the cleanup: the session cannot be accessed anymore by getOrCreateIsolatedSession. sessionHolder.foreach(shutdownSessionHolder(_)) } - private[connect] def shutdown(): Unit = sessionsLock.synchronized { - scheduledExecutor.foreach { executor => + private[connect] def shutdown(): Unit = { + val executor = scheduledExecutor.getAndSet(null) + if (executor != null) { ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES)) } - scheduledExecutor = None + // note: this does not cleanly shut down the sessions, but the server is shutting down. 
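Two hunks above are worth reading together: getSession collapses the lock-and-lookup dance into a single computeIfAbsent, and removeSessionHolder now records the tombstone in closedSessionsCache before removing the key, closing the window in which a retried request could silently recreate a just-closed session. A compact sketch of the get-or-create half, with a String standing in for SessionHolder:

import java.util.concurrent.ConcurrentHashMap

object GetOrCreateSketch {
  private val sessionStore = new ConcurrentHashMap[String, String]()

  // computeIfAbsent runs the factory at most once per key, even under concurrent callers,
  // so two racing requests for the same session key end up sharing a single holder.
  def getOrCreate(key: String)(create: () => String): String =
    sessionStore.computeIfAbsent(key, (k: String) => create())

  def getIfExists(key: String): Option[String] = Option(sessionStore.get(key))
}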
sessionStore.clear() closedSessionsCache.invalidateAll() } - def listActiveSessions: Seq[SessionHolderInfo] = sessionsLock.synchronized { - sessionStore.values.map(_.getSessionHolderInfo).toSeq + def listActiveSessions: Seq[SessionHolderInfo] = { + sessionStore.values().asScala.map(_.getSessionHolderInfo).toSeq } - def listClosedSessions: Seq[SessionHolderInfo] = sessionsLock.synchronized { + def listClosedSessions: Seq[SessionHolderInfo] = { closedSessionsCache.asMap.asScala.values.toSeq } @@ -199,16 +192,16 @@ class SparkConnectSessionManager extends Logging { * * The checks are looking to remove sessions that expired. */ - private def schedulePeriodicChecks(): Unit = sessionsLock.synchronized { - scheduledExecutor match { - case Some(_) => // Already running. - case None => + private def schedulePeriodicChecks(): Unit = { + var executor = scheduledExecutor.getAcquire() + if (executor == null) { + executor = Executors.newSingleThreadScheduledExecutor() + if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) { val interval = SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_MAINTENANCE_INTERVAL) logInfo( log"Starting thread for cleanup of expired sessions every " + log"${MDC(INTERVAL, interval)} ms") - scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor()) - scheduledExecutor.get.scheduleAtFixedRate( + executor.scheduleAtFixedRate( () => { try { val defaultInactiveTimeoutMs = @@ -221,6 +214,7 @@ class SparkConnectSessionManager extends Logging { interval, interval, TimeUnit.MILLISECONDS) + } } } @@ -246,34 +240,27 @@ class SparkConnectSessionManager extends Logging { timeoutMs != -1 && info.lastAccessTimeMs + timeoutMs <= nowMs } - sessionsLock.synchronized { - val nowMs = System.currentTimeMillis() - sessionStore.values.foreach { sessionHolder => - if (shouldExpire(sessionHolder.getSessionHolderInfo, nowMs)) { - toRemove += sessionHolder - } + val nowMs = System.currentTimeMillis() + sessionStore.forEach((_, sessionHolder) => { + if (shouldExpire(sessionHolder.getSessionHolderInfo, nowMs)) { + toRemove += sessionHolder } - } + }) + // .. and remove them. toRemove.foreach { sessionHolder => - // This doesn't use closeSession to be able to do the extra last chance check under lock. - val removedSession = sessionsLock.synchronized { - // Last chance - check expiration time and remove under lock if expired. - val info = sessionHolder.getSessionHolderInfo - if (shouldExpire(info, System.currentTimeMillis())) { - logInfo( - log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " + - log"and will be closed.") - removeSessionHolder(info.key) - } else { - None + val info = sessionHolder.getSessionHolderInfo + if (shouldExpire(info, System.currentTimeMillis())) { + logInfo( + log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " + + log"and will be closed.") + removeSessionHolder(info.key) + try { + shutdownSessionHolder(sessionHolder) + } catch { + case NonFatal(ex) => logWarning("Unexpected exception closing session", ex) } } - // do shutdown and cleanup outside of lock. 
- try removedSession.foreach(shutdownSessionHolder(_)) - catch { - case NonFatal(ex) => logWarning("Unexpected exception closing session", ex) - } } logInfo("Finished periodic run of SparkConnectSessionManager maintenance.") } @@ -309,7 +296,7 @@ class SparkConnectSessionManager extends Logging { /** * Used for testing */ - private[connect] def invalidateAllSessions(): Unit = sessionsLock.synchronized { + private[connect] def invalidateAllSessions(): Unit = { periodicMaintenance(defaultInactiveTimeoutMs = 0L, ignoreCustomTimeout = true) assert(sessionStore.isEmpty) closedSessionsCache.invalidateAll() diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala index 03719ddd87419..8241672d5107b 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service import java.util.concurrent.Executors import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -185,10 +186,10 @@ private[connect] class SparkConnectStreamingQueryCache( // Visible for testing. private[service] def shutdown(): Unit = queryCacheLock.synchronized { - scheduledExecutor.foreach { executor => + val executor = scheduledExecutor.getAndSet(null) + if (executor != null) { ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES)) } - scheduledExecutor = None } @GuardedBy("queryCacheLock") @@ -199,19 +200,19 @@ private[connect] class SparkConnectStreamingQueryCache( private val taggedQueries = new mutable.HashMap[String, mutable.ArrayBuffer[QueryCacheKey]] private val taggedQueriesLock = new Object - @GuardedBy("queryCacheLock") - private var scheduledExecutor: Option[ScheduledExecutorService] = None + private var scheduledExecutor: AtomicReference[ScheduledExecutorService] = + new AtomicReference[ScheduledExecutorService]() /** Schedules periodic checks if it is not already scheduled */ - private def schedulePeriodicChecks(): Unit = queryCacheLock.synchronized { - scheduledExecutor match { - case Some(_) => // Already running. 
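The schedulePeriodicChecks rewrite in this streaming-query-cache hunk mirrors the ones in the execution and session managers above: the maintenance thread is started lazily, and a getAcquire/compareAndExchangeRelease pair replaces the lock so exactly one thread installs and schedules the executor. The shape of that pattern, isolated; shutting down the losing thread's executor is an addition of this sketch, not part of the patch:

import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

object LazyMaintenanceStartSketch {
  private val scheduled = new AtomicReference[ScheduledExecutorService]()

  def ensureStarted(intervalMs: Long)(task: () => Unit): Unit = {
    var executor = scheduled.getAcquire()
    if (executor == null) {
      executor = Executors.newSingleThreadScheduledExecutor()
      // compareAndExchangeRelease returns the witness value: null means our executor was
      // installed and we are responsible for scheduling; anything else means we lost the race.
      if (scheduled.compareAndExchangeRelease(null, executor) == null) {
        executor.scheduleAtFixedRate(() => task(), intervalMs, intervalMs, TimeUnit.MILLISECONDS)
      } else {
        executor.shutdown() // discard the unused executor created by the losing thread
      }
    }
  }
}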
- case None => + private def schedulePeriodicChecks(): Unit = { + var executor = scheduledExecutor.getAcquire() + if (executor == null) { + executor = Executors.newSingleThreadScheduledExecutor() + if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) { logInfo( log"Starting thread for polling streaming sessions " + log"every ${MDC(DURATION, sessionPollingPeriod.toMillis)}") - scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor()) - scheduledExecutor.get.scheduleAtFixedRate( + executor.scheduleAtFixedRate( () => { try periodicMaintenance() catch { @@ -221,6 +222,7 @@ private[connect] class SparkConnectStreamingQueryCache( sessionPollingPeriod.toMillis, sessionPollingPeriod.toMillis, TimeUnit.MILLISECONDS) + } } } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index 355048cf30363..f1636ed1ef092 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -205,7 +205,9 @@ private[connect] object ErrorUtils extends Logging { case _ => } - if (sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED))) { + val enrichErrorEnabled = sessionHolderOpt.exists( + _.session.sessionState.conf.getConf(Connect.CONNECT_ENRICH_ERROR_ENABLED)) + if (enrichErrorEnabled) { // Generate a new unique key for this exception. val errorId = UUID.randomUUID().toString @@ -216,9 +218,10 @@ private[connect] object ErrorUtils extends Logging { } lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st)) + val stackTraceEnabled = sessionHolderOpt.exists( + _.session.sessionState.conf.getConf(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED)) val withStackTrace = - if (sessionHolderOpt.exists( - _.session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty)) { + if (stackTraceEnabled && stackTrace.nonEmpty) { val maxSize = Math.min( SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE), maxMetadataSize) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala index 4f6daded402a9..29ad97ad9fbe8 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala @@ -179,6 +179,11 @@ class ProtoToParsedPlanTestSuite logError(log"Skipping ${MDC(PATH, fileName)}") return } + // TODO: enable below by SPARK-49487 + if (fileName.contains("transpose")) { + logError(log"Skipping ${MDC(PATH, fileName)} because of SPARK-49487") + return + } val name = fileName.stripSuffix(".proto.bin") test(name) { val relation = readRelation(file) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala index 7c1b9362425d9..df5df23e77505 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala @@ -79,11 +79,16 @@ class 
     }
   }
 
-  test("taskDone") {
+  def testTaskDone(metricsPopulated: Boolean): Unit = {
     val listener = new ConnectProgressExecutionListener
     listener.registerJobTag(testTag)
     listener.onJobStart(testJobStart)
+    val metricsOrNull = if (metricsPopulated) {
+      testStage1Task1Metrics
+    } else {
+      null
+    }
 
     // Finish the tasks
     val taskEnd = SparkListenerTaskEnd(
       1,
@@ -92,7 +97,7 @@ class ConnectProgressExecutionListenerSuite extends SparkFunSuite with MockitoSu
       Success,
       testStage1Task1,
       testStage1Task1ExecutorMetrics,
-      testStage1Task1Metrics)
+      metricsOrNull)
 
     val t = listener.trackedTags(testTag)
     var yielded = false
@@ -117,7 +122,11 @@ class ConnectProgressExecutionListenerSuite extends SparkFunSuite with MockitoSu
       assert(stages.map(_.numTasks).sum == 2)
       assert(stages.map(_.completedTasks).sum == 1)
       assert(stages.size == 2)
-      assert(stages.map(_.inputBytesRead).sum == 500)
+      if (metricsPopulated) {
+        assert(stages.map(_.inputBytesRead).sum == 500)
+      } else {
+        assert(stages.map(_.inputBytesRead).sum == 0)
+      }
       assert(
         stages
           .map(_.completed match {
@@ -140,7 +149,11 @@ class ConnectProgressExecutionListenerSuite extends SparkFunSuite with MockitoSu
       assert(stages.map(_.numTasks).sum == 2)
       assert(stages.map(_.completedTasks).sum == 1)
       assert(stages.size == 2)
-      assert(stages.map(_.inputBytesRead).sum == 500)
+      if (metricsPopulated) {
+        assert(stages.map(_.inputBytesRead).sum == 500)
+      } else {
+        assert(stages.map(_.inputBytesRead).sum == 0)
+      }
       assert(
         stages
           .map(_.completed match {
@@ -153,4 +166,12 @@ class ConnectProgressExecutionListenerSuite extends SparkFunSuite with MockitoSu
     assert(yielded, "Must updated with results")
   }
 
+  test("taskDone - populated metrics") {
+    testTaskDone(metricsPopulated = true)
+  }
+
+  test("taskDone - null metrics") {
+    testTaskDone(metricsPopulated = false)
+  }
+
 }
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
index 25e6cc48a1998..2606284c25bd5 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
@@ -429,4 +429,27 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
     val abandonedExecutions = manager.listAbandonedExecutions
     assert(abandonedExecutions.forall(_.operationId != dummyOpId))
   }
+
+  test("SPARK-49492: reattach must not succeed on an inactive execution holder") {
+    withRawBlockingStub { stub =>
+      val operationId = UUID.randomUUID().toString
+
+      // supply an invalid plan so that the execute plan handler raises an error
+      val iter = stub.executePlan(
+        buildExecutePlanRequest(proto.Plan.newBuilder().build(), operationId = operationId))
+
+      // expect that the execution fails before spawning an execute thread
+      val ee = intercept[StatusRuntimeException] {
+        iter.next()
+      }
+      assert(ee.getMessage.contains("INTERNAL"))
+
+      // reattach must fail
+      val reattach = stub.reattachExecute(buildReattachExecuteRequest(operationId, None))
+      val re = intercept[StatusRuntimeException] {
+        reattach.hasNext()
+      }
+      assert(re.getMessage.contains("INVALID_HANDLE.OPERATION_NOT_FOUND"))
+    }
+  }
 }
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 03bf5a4c10dbc..cad7fe6370827 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -1042,7 +1042,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
         analyzePlan(
           transform(connectTestRelation.observe("my_metric", "id".protoAttr.cast("string"))))
       },
-      errorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
+      condition = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
       parameters = Map("expr" -> "\"CAST(id AS STRING) AS id\""))
 
     val connectPlan2 =
@@ -1073,7 +1073,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
         transform(
           connectTestRelation.observe(Observation("my_metric"), "id".protoAttr.cast("string"))))
       },
-      errorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
+      condition = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
       parameters = Map("expr" -> "\"CAST(id AS STRING) AS id\""))
   }
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 579fdb47aef3c..62146f19328a8 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -871,10 +871,16 @@ class SparkConnectServiceSuite
   class VerifyEvents(val sparkContext: SparkContext) {
     val listener: MockSparkListener = new MockSparkListener()
     val listenerBus = sparkContext.listenerBus
+    val EVENT_WAIT_TIMEOUT = timeout(10.seconds)
     val LISTENER_BUS_TIMEOUT = 30000
     def executeHolder: ExecuteHolder = {
-      assert(listener.executeHolder.isDefined)
-      listener.executeHolder.get
+      // An ExecuteHolder shall be set eventually through MockSparkListener
+      Eventually.eventually(EVENT_WAIT_TIMEOUT) {
+        assert(
+          listener.executeHolder.isDefined,
+          s"No events have been posted in $EVENT_WAIT_TIMEOUT")
+        listener.executeHolder.get
+      }
     }
     def onNext(v: proto.ExecutePlanResponse): Unit = {
       if (v.hasSchema) {
@@ -891,8 +897,10 @@ class SparkConnectServiceSuite
     def onCompleted(producedRowCount: Option[Long] = None): Unit = {
       assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount)
       // The eventsManager is closed asynchronously
-      Eventually.eventually(timeout(1.seconds)) {
-        assert(executeHolder.eventsManager.status == ExecuteStatus.Closed)
+      Eventually.eventually(EVENT_WAIT_TIMEOUT) {
+        assert(
+          executeHolder.eventsManager.status == ExecuteStatus.Closed,
+          s"Execution has not been completed in $EVENT_WAIT_TIMEOUT")
       }
     }
     def onCanceled(): Unit = {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
index 512cdad62b921..2b768875c6e20 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
@@ -217,7 +217,7 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
       exception = intercept[SparkException] {
         SparkConnectPluginRegistry.loadRelationPlugins()
       },
-      errorClass = "CONNECT.PLUGIN_CTOR_MISSING",
+      condition = "CONNECT.PLUGIN_CTOR_MISSING",
      parameters = Map("cls" -> "org.apache.spark.sql.connect.plugin.DummyPluginNoTrivialCtor"))
   }
 
@@ -228,7 +228,7 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
       exception = intercept[SparkException] {
         SparkConnectPluginRegistry.loadRelationPlugins()
       },
-      errorClass = "CONNECT.PLUGIN_RUNTIME_ERROR",
+      condition = "CONNECT.PLUGIN_RUNTIME_ERROR",
       parameters = Map("msg" -> "Bad Plugin Error"))
   }
 }
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
index dbe8420eab03d..a9843e261fff8 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
@@ -374,7 +374,8 @@ class ExecuteEventsManagerSuite
       .setClientType(DEFAULT_CLIENT_TYPE)
       .build()
 
-    val executeHolder = new ExecuteHolder(executePlanRequest, sessionHolder)
+    val executeKey = ExecuteKey(executePlanRequest, sessionHolder)
+    val executeHolder = new ExecuteHolder(executeKey, executePlanRequest, sessionHolder)
 
     val eventsManager = ExecuteEventsManager(executeHolder, DEFAULT_CLOCK)
     eventsManager.status_(executeStatus)
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala
index 8f76d58a31476..f125cb2d5c6c0 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala
@@ -118,7 +118,7 @@ class InterceptorRegistrySuite extends SharedSparkSession {
       exception = intercept[SparkException] {
         SparkConnectInterceptorRegistry.chainInterceptors(sb)
       },
-      errorClass = "CONNECT.INTERCEPTOR_CTOR_MISSING",
+      condition = "CONNECT.INTERCEPTOR_CTOR_MISSING",
       parameters = Map("cls" ->
         "org.apache.spark.sql.connect.service.TestingInterceptorNoTrivialCtor"))
   }
@@ -132,7 +132,7 @@ class InterceptorRegistrySuite extends SharedSparkSession {
       exception = intercept[SparkException] {
         SparkConnectInterceptorRegistry.createConfiguredInterceptors()
       },
-      errorClass = "CONNECT.INTERCEPTOR_CTOR_MISSING",
+      condition = "CONNECT.INTERCEPTOR_CTOR_MISSING",
       parameters = Map("cls" ->
         "org.apache.spark.sql.connect.service.TestingInterceptorNoTrivialCtor"))
   }
@@ -144,7 +144,7 @@ class InterceptorRegistrySuite extends SharedSparkSession {
       exception = intercept[SparkException] {
         SparkConnectInterceptorRegistry.createConfiguredInterceptors()
       },
-      errorClass = "CONNECT.INTERCEPTOR_RUNTIME_ERROR",
+      condition = "CONNECT.INTERCEPTOR_RUNTIME_ERROR",
       parameters = Map("msg" -> "Bad Error"))
   }
 }
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index beebe5d2e2dc1..ed2f60afb0096 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -399,7 +399,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
   test("Test session plan cache - disabled") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
     // Disable plan cache of the session
-    sessionHolder.session.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, false)
+    sessionHolder.session.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED.key, false)
     val planner = new SparkConnectPlanner(sessionHolder)
 
     val query = buildRelation("select 1")
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto
index 113bb8a8b00fb..1ff90f27e173a 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto
@@ -59,6 +59,7 @@ message ImplicitGroupingKeyRequest {
 message StateCallCommand {
   string stateName = 1;
   string schema = 2;
+  TTLConfig ttl = 3;
 }
 
 message ValueStateCall {
@@ -101,3 +102,7 @@ enum HandleState {
 message SetHandleState {
   HandleState state = 1;
 }
+
+message TTLConfig {
+  int32 durationMs = 1;
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java
index 542bcd8a68abf..4fbb20be05b7b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java
@@ -15,9 +15,7 @@
  * limitations under the License.
  */
 // Generated by the protocol buffer compiler. DO NOT EDIT!
-// NO CHECKED-IN PROTOBUF GENCODE // source: StateMessage.proto -// Protobuf Java Version: 4.27.1 package org.apache.spark.sql.execution.streaming.state; @@ -213,7 +211,7 @@ public interface StateRequestOrBuilder extends */ org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequestOrBuilder getImplicitGroupingKeyRequestOrBuilder(); - org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest.MethodCase getMethodCase(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest.MethodCase getMethodCase(); } /** * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.StateRequest} @@ -230,6 +228,18 @@ private StateRequest(com.google.protobuf.GeneratedMessageV3.Builder builder) private StateRequest() { } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new StateRequest(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_StateRequest_descriptor; @@ -244,7 +254,6 @@ private StateRequest() { } private int methodCase_ = 0; - @SuppressWarnings("serial") private java.lang.Object method_; public enum MethodCase implements com.google.protobuf.Internal.EnumLite, @@ -288,7 +297,7 @@ public int getNumber() { } public static final int VERSION_FIELD_NUMBER = 1; - private int version_ = 0; + private int version_; /** * int32 version = 1; * @return The version. @@ -554,13 +563,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateR return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) @@ -635,8 +642,8 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - bitField0_ = 0; version_ = 0; + if (statefulProcessorCallBuilder_ != null) { statefulProcessorCallBuilder_.clear(); } @@ -674,36 +681,65 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest @java.lang.Override public org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest buildPartial() { org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest result = new org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest(this); - if (bitField0_ != 0) { buildPartial0(result); } - buildPartialOneofs(result); + result.version_ = version_; + if (methodCase_ == 2) { + if (statefulProcessorCallBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = statefulProcessorCallBuilder_.build(); + } + } + if (methodCase_ == 3) { + if (stateVariableRequestBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = stateVariableRequestBuilder_.build(); + } + } + if (methodCase_ == 4) { + if (implicitGroupingKeyRequestBuilder_ == null) { + 
result.method_ = method_; + } else { + result.method_ = implicitGroupingKeyRequestBuilder_.build(); + } + } + result.methodCase_ = methodCase_; onBuilt(); return result; } - private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest result) { - int from_bitField0_ = bitField0_; - if (((from_bitField0_ & 0x00000001) != 0)) { - result.version_ = version_; - } + @java.lang.Override + public Builder clone() { + return super.clone(); } - - private void buildPartialOneofs(org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest result) { - result.methodCase_ = methodCase_; - result.method_ = this.method_; - if (methodCase_ == 2 && - statefulProcessorCallBuilder_ != null) { - result.method_ = statefulProcessorCallBuilder_.build(); - } - if (methodCase_ == 3 && - stateVariableRequestBuilder_ != null) { - result.method_ = stateVariableRequestBuilder_.build(); - } - if (methodCase_ == 4 && - implicitGroupingKeyRequestBuilder_ != null) { - result.method_ = implicitGroupingKeyRequestBuilder_.build(); - } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); } - @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest) { @@ -764,7 +800,7 @@ public Builder mergeFrom( break; case 8: { version_ = input.readInt32(); - bitField0_ |= 0x00000001; + break; } // case 8 case 18: { @@ -818,7 +854,6 @@ public Builder clearMethod() { return this; } - private int bitField0_; private int version_ ; /** @@ -835,9 +870,8 @@ public int getVersion() { * @return This builder for chaining. */ public Builder setVersion(int value) { - + version_ = value; - bitField0_ |= 0x00000001; onChanged(); return this; } @@ -846,7 +880,7 @@ public Builder setVersion(int value) { * @return This builder for chaining. 
*/ public Builder clearVersion() { - bitField0_ = (bitField0_ & ~0x00000001); + version_ = 0; onChanged(); return this; @@ -990,7 +1024,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProce method_ = null; } methodCase_ = 2; - onChanged(); + onChanged();; return statefulProcessorCallBuilder_; } @@ -1132,7 +1166,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariable method_ = null; } methodCase_ = 3; - onChanged(); + onChanged();; return stateVariableRequestBuilder_; } @@ -1274,9 +1308,21 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroup method_ = null; } methodCase_ = 4; - onChanged(); + onChanged();; return implicitGroupingKeyRequestBuilder_; } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.StateRequest) } @@ -1374,6 +1420,18 @@ private StateResponse() { value_ = com.google.protobuf.ByteString.EMPTY; } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new StateResponse(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_StateResponse_descriptor; @@ -1388,7 +1446,7 @@ private StateResponse() { } public static final int STATUSCODE_FIELD_NUMBER = 1; - private int statusCode_ = 0; + private int statusCode_; /** * int32 statusCode = 1; * @return The statusCode. @@ -1399,8 +1457,7 @@ public int getStatusCode() { } public static final int ERRORMESSAGE_FIELD_NUMBER = 2; - @SuppressWarnings("serial") - private volatile java.lang.Object errorMessage_ = ""; + private volatile java.lang.Object errorMessage_; /** * string errorMessage = 2; * @return The errorMessage. @@ -1438,7 +1495,7 @@ public java.lang.String getErrorMessage() { } public static final int VALUE_FIELD_NUMBER = 3; - private com.google.protobuf.ByteString value_ = com.google.protobuf.ByteString.EMPTY; + private com.google.protobuf.ByteString value_; /** * bytes value = 3; * @return The value. 
@@ -1578,13 +1635,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateR return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) @@ -1659,10 +1714,12 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - bitField0_ = 0; statusCode_ = 0; + errorMessage_ = ""; + value_ = com.google.protobuf.ByteString.EMPTY; + return this; } @@ -1689,24 +1746,45 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse @java.lang.Override public org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse buildPartial() { org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse result = new org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse(this); - if (bitField0_ != 0) { buildPartial0(result); } + result.statusCode_ = statusCode_; + result.errorMessage_ = errorMessage_; + result.value_ = value_; onBuilt(); return result; } - private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse result) { - int from_bitField0_ = bitField0_; - if (((from_bitField0_ & 0x00000001) != 0)) { - result.statusCode_ = statusCode_; - } - if (((from_bitField0_ & 0x00000002) != 0)) { - result.errorMessage_ = errorMessage_; - } - if (((from_bitField0_ & 0x00000004) != 0)) { - result.value_ = value_; - } + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); } - @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse) { @@ -1724,7 +1802,6 @@ public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMes } if (!other.getErrorMessage().isEmpty()) { errorMessage_ = other.errorMessage_; - bitField0_ |= 0x00000002; onChanged(); } if (other.getValue() != com.google.protobuf.ByteString.EMPTY) { @@ -1758,17 +1835,17 @@ public Builder mergeFrom( break; case 8: { statusCode_ = input.readInt32(); - bitField0_ |= 0x00000001; + break; } // case 8 case 18: { errorMessage_ = input.readStringRequireUtf8(); - bitField0_ |= 0x00000002; + break; } // case 18 case 
26: { value_ = input.readBytes(); - bitField0_ |= 0x00000004; + break; } // case 26 default: { @@ -1786,7 +1863,6 @@ public Builder mergeFrom( } // finally return this; } - private int bitField0_; private int statusCode_ ; /** @@ -1803,9 +1879,8 @@ public int getStatusCode() { * @return This builder for chaining. */ public Builder setStatusCode(int value) { - + statusCode_ = value; - bitField0_ |= 0x00000001; onChanged(); return this; } @@ -1814,7 +1889,7 @@ public Builder setStatusCode(int value) { * @return This builder for chaining. */ public Builder clearStatusCode() { - bitField0_ = (bitField0_ & ~0x00000001); + statusCode_ = 0; onChanged(); return this; @@ -1861,9 +1936,11 @@ public java.lang.String getErrorMessage() { */ public Builder setErrorMessage( java.lang.String value) { - if (value == null) { throw new NullPointerException(); } + if (value == null) { + throw new NullPointerException(); + } + errorMessage_ = value; - bitField0_ |= 0x00000002; onChanged(); return this; } @@ -1872,8 +1949,8 @@ public Builder setErrorMessage( * @return This builder for chaining. */ public Builder clearErrorMessage() { + errorMessage_ = getDefaultInstance().getErrorMessage(); - bitField0_ = (bitField0_ & ~0x00000002); onChanged(); return this; } @@ -1884,10 +1961,12 @@ public Builder clearErrorMessage() { */ public Builder setErrorMessageBytes( com.google.protobuf.ByteString value) { - if (value == null) { throw new NullPointerException(); } - checkByteStringIsUtf8(value); + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + errorMessage_ = value; - bitField0_ |= 0x00000002; onChanged(); return this; } @@ -1907,9 +1986,11 @@ public com.google.protobuf.ByteString getValue() { * @return This builder for chaining. */ public Builder setValue(com.google.protobuf.ByteString value) { - if (value == null) { throw new NullPointerException(); } + if (value == null) { + throw new NullPointerException(); + } + value_ = value; - bitField0_ |= 0x00000004; onChanged(); return this; } @@ -1918,11 +1999,23 @@ public Builder setValue(com.google.protobuf.ByteString value) { * @return This builder for chaining. 
*/ public Builder clearValue() { - bitField0_ = (bitField0_ & ~0x00000004); + value_ = getDefaultInstance().getValue(); onChanged(); return this; } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.StateResponse) } @@ -2039,7 +2132,7 @@ public interface StatefulProcessorCallOrBuilder extends */ org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommandOrBuilder getGetMapStateOrBuilder(); - org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall.MethodCase getMethodCase(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall.MethodCase getMethodCase(); } /** * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.StatefulProcessorCall} @@ -2056,6 +2149,18 @@ private StatefulProcessorCall(com.google.protobuf.GeneratedMessageV3.Builder private StatefulProcessorCall() { } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new StatefulProcessorCall(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_StatefulProcessorCall_descriptor; @@ -2070,7 +2175,6 @@ private StatefulProcessorCall() { } private int methodCase_ = 0; - @SuppressWarnings("serial") private java.lang.Object method_; public enum MethodCase implements com.google.protobuf.Internal.EnumLite, @@ -2406,13 +2510,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Statef return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) @@ -2487,7 +2589,6 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - bitField0_ = 0; if (setHandleStateBuilder_ != null) { setHandleStateBuilder_.clear(); } @@ -2528,37 +2629,71 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProce @java.lang.Override public org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall buildPartial() { org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall result = new org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall(this); - if (bitField0_ != 0) { buildPartial0(result); } - buildPartialOneofs(result); - onBuilt(); - return result; - } - - private void 
buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall result) { - int from_bitField0_ = bitField0_; - } - - private void buildPartialOneofs(org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall result) { - result.methodCase_ = methodCase_; - result.method_ = this.method_; - if (methodCase_ == 1 && - setHandleStateBuilder_ != null) { - result.method_ = setHandleStateBuilder_.build(); + if (methodCase_ == 1) { + if (setHandleStateBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = setHandleStateBuilder_.build(); + } } - if (methodCase_ == 2 && - getValueStateBuilder_ != null) { - result.method_ = getValueStateBuilder_.build(); + if (methodCase_ == 2) { + if (getValueStateBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = getValueStateBuilder_.build(); + } } - if (methodCase_ == 3 && - getListStateBuilder_ != null) { - result.method_ = getListStateBuilder_.build(); + if (methodCase_ == 3) { + if (getListStateBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = getListStateBuilder_.build(); + } } - if (methodCase_ == 4 && - getMapStateBuilder_ != null) { - result.method_ = getMapStateBuilder_.build(); + if (methodCase_ == 4) { + if (getMapStateBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = getMapStateBuilder_.build(); + } } + result.methodCase_ = methodCase_; + onBuilt(); + return result; } + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall) { @@ -2676,7 +2811,6 @@ public Builder clearMethod() { return this; } - private int bitField0_; private com.google.protobuf.SingleFieldBuilderV3< org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState, org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStateOrBuilder> setHandleStateBuilder_; @@ -2816,7 +2950,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStat method_ = null; } methodCase_ = 1; - onChanged(); + onChanged();; return setHandleStateBuilder_; } @@ -2958,7 +3092,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallComm method_ = null; } methodCase_ = 2; - onChanged(); + onChanged();; return getValueStateBuilder_; } @@ -3100,7 +3234,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallComm 
method_ = null; } methodCase_ = 3; - onChanged(); + onChanged();; return getListStateBuilder_; } @@ -3242,9 +3376,21 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallComm method_ = null; } methodCase_ = 4; - onChanged(); + onChanged();; return getMapStateBuilder_; } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.StatefulProcessorCall) } @@ -3316,7 +3462,7 @@ public interface StateVariableRequestOrBuilder extends */ org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCallOrBuilder getValueStateCallOrBuilder(); - org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest.MethodCase getMethodCase(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest.MethodCase getMethodCase(); } /** * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.StateVariableRequest} @@ -3333,6 +3479,18 @@ private StateVariableRequest(com.google.protobuf.GeneratedMessageV3.Builder b private StateVariableRequest() { } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new StateVariableRequest(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_StateVariableRequest_descriptor; @@ -3347,7 +3505,6 @@ private StateVariableRequest() { } private int methodCase_ = 0; - @SuppressWarnings("serial") private java.lang.Object method_; public enum MethodCase implements com.google.protobuf.Internal.EnumLite, @@ -3539,13 +3696,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateV return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) @@ -3620,7 +3775,6 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - bitField0_ = 0; if (valueStateCallBuilder_ != null) { valueStateCallBuilder_.clear(); } @@ -3652,25 +3806,50 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariable @java.lang.Override public org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest buildPartial() { org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest result = new org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest(this); - if (bitField0_ != 0) { buildPartial0(result); } - 
buildPartialOneofs(result); + if (methodCase_ == 1) { + if (valueStateCallBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = valueStateCallBuilder_.build(); + } + } + result.methodCase_ = methodCase_; onBuilt(); return result; } - private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest result) { - int from_bitField0_ = bitField0_; + @java.lang.Override + public Builder clone() { + return super.clone(); } - - private void buildPartialOneofs(org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest result) { - result.methodCase_ = methodCase_; - result.method_ = this.method_; - if (methodCase_ == 1 && - valueStateCallBuilder_ != null) { - result.method_ = valueStateCallBuilder_.build(); - } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); } - @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest) { @@ -3755,7 +3934,6 @@ public Builder clearMethod() { return this; } - private int bitField0_; private com.google.protobuf.SingleFieldBuilderV3< org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCallOrBuilder> valueStateCallBuilder_; @@ -3895,9 +4073,21 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCal method_ = null; } methodCase_ = 1; - onChanged(); + onChanged();; return valueStateCallBuilder_; } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.StateVariableRequest) } @@ -3984,7 +4174,7 @@ public interface ImplicitGroupingKeyRequestOrBuilder extends */ org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKeyOrBuilder getRemoveImplicitKeyOrBuilder(); - org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest.MethodCase getMethodCase(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest.MethodCase getMethodCase(); } /** * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequest} @@ -4001,6 +4191,18 @@ private 
ImplicitGroupingKeyRequest(com.google.protobuf.GeneratedMessageV3.Builde private ImplicitGroupingKeyRequest() { } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new ImplicitGroupingKeyRequest(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ImplicitGroupingKeyRequest_descriptor; @@ -4015,7 +4217,6 @@ private ImplicitGroupingKeyRequest() { } private int methodCase_ = 0; - @SuppressWarnings("serial") private java.lang.Object method_; public enum MethodCase implements com.google.protobuf.Internal.EnumLite, @@ -4255,13 +4456,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Implic return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) @@ -4336,7 +4535,6 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - bitField0_ = 0; if (setImplicitKeyBuilder_ != null) { setImplicitKeyBuilder_.clear(); } @@ -4371,29 +4569,57 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroup @java.lang.Override public org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest buildPartial() { org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest(this); - if (bitField0_ != 0) { buildPartial0(result); } - buildPartialOneofs(result); + if (methodCase_ == 1) { + if (setImplicitKeyBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = setImplicitKeyBuilder_.build(); + } + } + if (methodCase_ == 2) { + if (removeImplicitKeyBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = removeImplicitKeyBuilder_.build(); + } + } + result.methodCase_ = methodCase_; onBuilt(); return result; } - private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest result) { - int from_bitField0_ = bitField0_; + @java.lang.Override + public Builder clone() { + return super.clone(); } - - private void buildPartialOneofs(org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest result) { - result.methodCase_ = methodCase_; - result.method_ = this.method_; - if (methodCase_ == 1 && - setImplicitKeyBuilder_ != null) { - result.method_ = setImplicitKeyBuilder_.build(); - } - if (methodCase_ == 2 && - removeImplicitKeyBuilder_ != null) { - result.method_ = removeImplicitKeyBuilder_.build(); - } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + 
@java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); } - @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest) { @@ -4489,7 +4715,6 @@ public Builder clearMethod() { return this; } - private int bitField0_; private com.google.protobuf.SingleFieldBuilderV3< org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKeyOrBuilder> setImplicitKeyBuilder_; @@ -4629,7 +4854,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKe method_ = null; } methodCase_ = 1; - onChanged(); + onChanged();; return setImplicitKeyBuilder_; } @@ -4771,9 +4996,21 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplici method_ = null; } methodCase_ = 2; - onChanged(); + onChanged();; return removeImplicitKeyBuilder_; } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequest) } @@ -4853,6 +5090,21 @@ public interface StateCallCommandOrBuilder extends */ com.google.protobuf.ByteString getSchemaBytes(); + + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * @return Whether the ttl field is set. + */ + boolean hasTtl(); + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * @return The ttl. 
+ */ + org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig getTtl(); + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder getTtlOrBuilder(); } /** * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.StateCallCommand} @@ -4871,6 +5123,18 @@ private StateCallCommand() { schema_ = ""; } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new StateCallCommand(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_StateCallCommand_descriptor; @@ -4885,8 +5149,7 @@ private StateCallCommand() { } public static final int STATENAME_FIELD_NUMBER = 1; - @SuppressWarnings("serial") - private volatile java.lang.Object stateName_ = ""; + private volatile java.lang.Object stateName_; /** * string stateName = 1; * @return The stateName. @@ -4924,8 +5187,7 @@ public java.lang.String getStateName() { } public static final int SCHEMA_FIELD_NUMBER = 2; - @SuppressWarnings("serial") - private volatile java.lang.Object schema_ = ""; + private volatile java.lang.Object schema_; /** * string schema = 2; * @return The schema. @@ -4962,6 +5224,32 @@ public java.lang.String getSchema() { } } + public static final int TTL_FIELD_NUMBER = 3; + private org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig ttl_; + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * @return Whether the ttl field is set. + */ + @java.lang.Override + public boolean hasTtl() { + return ttl_ != null; + } + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * @return The ttl. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig getTtl() { + return ttl_ == null ? 
org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.getDefaultInstance() : ttl_; + } + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder getTtlOrBuilder() { + return getTtl(); + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -4982,6 +5270,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(schema_)) { com.google.protobuf.GeneratedMessageV3.writeString(output, 2, schema_); } + if (ttl_ != null) { + output.writeMessage(3, getTtl()); + } getUnknownFields().writeTo(output); } @@ -4997,6 +5288,10 @@ public int getSerializedSize() { if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(schema_)) { size += com.google.protobuf.GeneratedMessageV3.computeStringSize(2, schema_); } + if (ttl_ != null) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(3, getTtl()); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -5016,6 +5311,11 @@ public boolean equals(final java.lang.Object obj) { .equals(other.getStateName())) return false; if (!getSchema() .equals(other.getSchema())) return false; + if (hasTtl() != other.hasTtl()) return false; + if (hasTtl()) { + if (!getTtl() + .equals(other.getTtl())) return false; + } if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; } @@ -5031,6 +5331,10 @@ public int hashCode() { hash = (53 * hash) + getStateName().hashCode(); hash = (37 * hash) + SCHEMA_FIELD_NUMBER; hash = (53 * hash) + getSchema().hashCode(); + if (hasTtl()) { + hash = (37 * hash) + TTL_FIELD_NUMBER; + hash = (53 * hash) + getTtl().hashCode(); + } hash = (29 * hash) + getUnknownFields().hashCode(); memoizedHashCode = hash; return hash; @@ -5080,13 +5384,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateC return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) @@ -5161,9 +5463,16 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - bitField0_ = 0; stateName_ = ""; + schema_ = ""; + + if (ttlBuilder_ == null) { + ttl_ = null; + } else { + ttl_ = null; + ttlBuilder_ = null; + } return this; } @@ -5190,21 +5499,49 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallComm @java.lang.Override public org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand buildPartial() { org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand result = new org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand(this); - if (bitField0_ != 0) { buildPartial0(result); } + result.stateName_ = stateName_; + result.schema_ = schema_; + if (ttlBuilder_ == null) { + result.ttl_ = ttl_; + } else { + result.ttl_ = ttlBuilder_.build(); + } onBuilt(); return result; } - private void 
buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand result) { - int from_bitField0_ = bitField0_; - if (((from_bitField0_ & 0x00000001) != 0)) { - result.stateName_ = stateName_; - } - if (((from_bitField0_ & 0x00000002) != 0)) { - result.schema_ = schema_; - } + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); } - @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand) { @@ -5219,14 +5556,15 @@ public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMes if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand.getDefaultInstance()) return this; if (!other.getStateName().isEmpty()) { stateName_ = other.stateName_; - bitField0_ |= 0x00000001; onChanged(); } if (!other.getSchema().isEmpty()) { schema_ = other.schema_; - bitField0_ |= 0x00000002; onChanged(); } + if (other.hasTtl()) { + mergeTtl(other.getTtl()); + } this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -5255,14 +5593,21 @@ public Builder mergeFrom( break; case 10: { stateName_ = input.readStringRequireUtf8(); - bitField0_ |= 0x00000001; + break; } // case 10 case 18: { schema_ = input.readStringRequireUtf8(); - bitField0_ |= 0x00000002; + break; } // case 18 + case 26: { + input.readMessage( + getTtlFieldBuilder().getBuilder(), + extensionRegistry); + + break; + } // case 26 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -5278,7 +5623,6 @@ public Builder mergeFrom( } // finally return this; } - private int bitField0_; private java.lang.Object stateName_ = ""; /** @@ -5321,9 +5665,11 @@ public java.lang.String getStateName() { */ public Builder setStateName( java.lang.String value) { - if (value == null) { throw new NullPointerException(); } + if (value == null) { + throw new NullPointerException(); + } + stateName_ = value; - bitField0_ |= 0x00000001; onChanged(); return this; } @@ -5332,8 +5678,8 @@ public Builder setStateName( * @return This builder for chaining. 
*/ public Builder clearStateName() { + stateName_ = getDefaultInstance().getStateName(); - bitField0_ = (bitField0_ & ~0x00000001); onChanged(); return this; } @@ -5344,10 +5690,12 @@ public Builder clearStateName() { */ public Builder setStateNameBytes( com.google.protobuf.ByteString value) { - if (value == null) { throw new NullPointerException(); } - checkByteStringIsUtf8(value); + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + stateName_ = value; - bitField0_ |= 0x00000001; onChanged(); return this; } @@ -5393,9 +5741,11 @@ public java.lang.String getSchema() { */ public Builder setSchema( java.lang.String value) { - if (value == null) { throw new NullPointerException(); } + if (value == null) { + throw new NullPointerException(); + } + schema_ = value; - bitField0_ |= 0x00000002; onChanged(); return this; } @@ -5404,8 +5754,8 @@ public Builder setSchema( * @return This builder for chaining. */ public Builder clearSchema() { + schema_ = getDefaultInstance().getSchema(); - bitField0_ = (bitField0_ & ~0x00000002); onChanged(); return this; } @@ -5416,14 +5766,147 @@ public Builder clearSchema() { */ public Builder setSchemaBytes( com.google.protobuf.ByteString value) { - if (value == null) { throw new NullPointerException(); } - checkByteStringIsUtf8(value); + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + schema_ = value; - bitField0_ |= 0x00000002; onChanged(); return this; } + private org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig ttl_; + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder> ttlBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * @return Whether the ttl field is set. + */ + public boolean hasTtl() { + return ttlBuilder_ != null || ttl_ != null; + } + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + * @return The ttl. + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig getTtl() { + if (ttlBuilder_ == null) { + return ttl_ == null ? 
org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.getDefaultInstance() : ttl_; + } else { + return ttlBuilder_.getMessage(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + */ + public Builder setTtl(org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig value) { + if (ttlBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ttl_ = value; + onChanged(); + } else { + ttlBuilder_.setMessage(value); + } + + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + */ + public Builder setTtl( + org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder builderForValue) { + if (ttlBuilder_ == null) { + ttl_ = builderForValue.build(); + onChanged(); + } else { + ttlBuilder_.setMessage(builderForValue.build()); + } + + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + */ + public Builder mergeTtl(org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig value) { + if (ttlBuilder_ == null) { + if (ttl_ != null) { + ttl_ = + org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.newBuilder(ttl_).mergeFrom(value).buildPartial(); + } else { + ttl_ = value; + } + onChanged(); + } else { + ttlBuilder_.mergeFrom(value); + } + + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + */ + public Builder clearTtl() { + if (ttlBuilder_ == null) { + ttl_ = null; + onChanged(); + } else { + ttl_ = null; + ttlBuilder_ = null; + } + + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder getTtlBuilder() { + + onChanged(); + return getTtlFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder getTtlOrBuilder() { + if (ttlBuilder_ != null) { + return ttlBuilder_.getMessageOrBuilder(); + } else { + return ttl_ == null ? 
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.getDefaultInstance() : ttl_; + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder> + getTtlFieldBuilder() { + if (ttlBuilder_ == null) { + ttlBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder>( + getTtl(), + getParentForChildren(), + isClean()); + ttl_ = null; + } + return ttlBuilder_; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.StateCallCommand) } @@ -5551,7 +6034,7 @@ public interface ValueStateCallOrBuilder extends */ org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder getClearOrBuilder(); - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall.MethodCase getMethodCase(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall.MethodCase getMethodCase(); } /** * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ValueStateCall} @@ -5569,6 +6052,18 @@ private ValueStateCall() { stateName_ = ""; } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new ValueStateCall(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateCall_descriptor; @@ -5583,7 +6078,6 @@ private ValueStateCall() { } private int methodCase_ = 0; - @SuppressWarnings("serial") private java.lang.Object method_; public enum MethodCase implements com.google.protobuf.Internal.EnumLite, @@ -5629,8 +6123,7 @@ public int getNumber() { } public static final int STATENAME_FIELD_NUMBER = 1; - @SuppressWarnings("serial") - private volatile java.lang.Object stateName_ = ""; + private volatile java.lang.Object stateName_; /** * string stateName = 1; * @return The stateName. 
@@ -5968,13 +6461,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueS return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) @@ -6049,8 +6540,8 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - bitField0_ = 0; stateName_ = ""; + if (existsBuilder_ != null) { existsBuilder_.clear(); } @@ -6091,40 +6582,72 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCal @java.lang.Override public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall buildPartial() { org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall(this); - if (bitField0_ != 0) { buildPartial0(result); } - buildPartialOneofs(result); - onBuilt(); - return result; - } - - private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall result) { - int from_bitField0_ = bitField0_; - if (((from_bitField0_ & 0x00000001) != 0)) { - result.stateName_ = stateName_; - } - } - - private void buildPartialOneofs(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall result) { - result.methodCase_ = methodCase_; - result.method_ = this.method_; - if (methodCase_ == 2 && - existsBuilder_ != null) { - result.method_ = existsBuilder_.build(); + result.stateName_ = stateName_; + if (methodCase_ == 2) { + if (existsBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = existsBuilder_.build(); + } } - if (methodCase_ == 3 && - getBuilder_ != null) { - result.method_ = getBuilder_.build(); + if (methodCase_ == 3) { + if (getBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = getBuilder_.build(); + } } - if (methodCase_ == 4 && - valueStateUpdateBuilder_ != null) { - result.method_ = valueStateUpdateBuilder_.build(); + if (methodCase_ == 4) { + if (valueStateUpdateBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = valueStateUpdateBuilder_.build(); + } } - if (methodCase_ == 5 && - clearBuilder_ != null) { - result.method_ = clearBuilder_.build(); + if (methodCase_ == 5) { + if (clearBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = clearBuilder_.build(); + } } + result.methodCase_ = methodCase_; + onBuilt(); + return result; } + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + 
com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall) { @@ -6139,7 +6662,6 @@ public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMes if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall.getDefaultInstance()) return this; if (!other.getStateName().isEmpty()) { stateName_ = other.stateName_; - bitField0_ |= 0x00000001; onChanged(); } switch (other.getMethodCase()) { @@ -6191,7 +6713,7 @@ public Builder mergeFrom( break; case 10: { stateName_ = input.readStringRequireUtf8(); - bitField0_ |= 0x00000001; + break; } // case 10 case 18: { @@ -6252,7 +6774,6 @@ public Builder clearMethod() { return this; } - private int bitField0_; private java.lang.Object stateName_ = ""; /** @@ -6295,9 +6816,11 @@ public java.lang.String getStateName() { */ public Builder setStateName( java.lang.String value) { - if (value == null) { throw new NullPointerException(); } + if (value == null) { + throw new NullPointerException(); + } + stateName_ = value; - bitField0_ |= 0x00000001; onChanged(); return this; } @@ -6306,8 +6829,8 @@ public Builder setStateName( * @return This builder for chaining. */ public Builder clearStateName() { + stateName_ = getDefaultInstance().getStateName(); - bitField0_ = (bitField0_ & ~0x00000001); onChanged(); return this; } @@ -6318,10 +6841,12 @@ public Builder clearStateName() { */ public Builder setStateNameBytes( com.google.protobuf.ByteString value) { - if (value == null) { throw new NullPointerException(); } - checkByteStringIsUtf8(value); + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + stateName_ = value; - bitField0_ |= 0x00000001; onChanged(); return this; } @@ -6464,7 +6989,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuild method_ = null; } methodCase_ = 2; - onChanged(); + onChanged();; return existsBuilder_; } @@ -6606,7 +7131,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.GetOrBuilder method_ = null; } methodCase_ = 3; - onChanged(); + onChanged();; return getBuilder_; } @@ -6748,7 +7273,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpd method_ = null; } methodCase_ = 4; - onChanged(); + onChanged();; return valueStateUpdateBuilder_; } @@ -6890,9 +7415,21 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilde method_ = null; } methodCase_ = 5; - onChanged(); + onChanged();; return clearBuilder_; } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ValueStateCall) } @@ -6971,6 +7508,18 @@ private SetImplicitKey() { key_ = com.google.protobuf.ByteString.EMPTY; } + @java.lang.Override + 
@SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new SetImplicitKey(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; @@ -6985,7 +7534,7 @@ private SetImplicitKey() { } public static final int KEY_FIELD_NUMBER = 1; - private com.google.protobuf.ByteString key_ = com.google.protobuf.ByteString.EMPTY; + private com.google.protobuf.ByteString key_; /** * bytes key = 1; * @return The key. @@ -7104,13 +7653,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImp return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) @@ -7185,8 +7732,8 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - bitField0_ = 0; key_ = com.google.protobuf.ByteString.EMPTY; + return this; } @@ -7213,18 +7760,43 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKe @java.lang.Override public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey buildPartial() { org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey(this); - if (bitField0_ != 0) { buildPartial0(result); } + result.key_ = key_; onBuilt(); return result; } - private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result) { - int from_bitField0_ = bitField0_; - if (((from_bitField0_ & 0x00000001) != 0)) { - result.key_ = key_; - } + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); } - @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey) { @@ -7268,7 +7840,7 @@ public Builder mergeFrom( break; case 10: { key_ = 
input.readBytes(); - bitField0_ |= 0x00000001; + break; } // case 10 default: { @@ -7286,7 +7858,6 @@ public Builder mergeFrom( } // finally return this; } - private int bitField0_; private com.google.protobuf.ByteString key_ = com.google.protobuf.ByteString.EMPTY; /** @@ -7303,9 +7874,11 @@ public com.google.protobuf.ByteString getKey() { * @return This builder for chaining. */ public Builder setKey(com.google.protobuf.ByteString value) { - if (value == null) { throw new NullPointerException(); } + if (value == null) { + throw new NullPointerException(); + } + key_ = value; - bitField0_ |= 0x00000001; onChanged(); return this; } @@ -7314,11 +7887,23 @@ public Builder setKey(com.google.protobuf.ByteString value) { * @return This builder for chaining. */ public Builder clearKey() { - bitField0_ = (bitField0_ & ~0x00000001); + key_ = getDefaultInstance().getKey(); onChanged(); return this; } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) } @@ -7390,6 +7975,18 @@ private RemoveImplicitKey(com.google.protobuf.GeneratedMessageV3.Builder buil private RemoveImplicitKey() { } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new RemoveImplicitKey(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; @@ -7501,13 +8098,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Remove return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) @@ -7612,6 +8207,38 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplici return result; } + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return 
super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey) { @@ -7665,6 +8292,18 @@ public Builder mergeFrom( } // finally return this; } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) } @@ -7736,6 +8375,18 @@ private Exists(com.google.protobuf.GeneratedMessageV3.Builder builder) { private Exists() { } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new Exists(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; @@ -7847,13 +8498,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) @@ -7958,6 +8607,38 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists buildP return result; } + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) { @@ -8011,6 +8692,18 @@ public Builder 
mergeFrom( } // finally return this; } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Exists) } @@ -8082,6 +8775,18 @@ private Get(com.google.protobuf.GeneratedMessageV3.Builder builder) { private Get() { } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new Get(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; @@ -8193,13 +8898,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get pa return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) @@ -8304,6 +9007,38 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Get buildPart return result; } + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Get) { @@ -8357,6 +9092,18 @@ public Builder mergeFrom( } // finally return this; } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Get) } @@ 
-8435,6 +9182,18 @@ private ValueStateUpdate() { value_ = com.google.protobuf.ByteString.EMPTY; } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new ValueStateUpdate(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; @@ -8449,7 +9208,7 @@ private ValueStateUpdate() { } public static final int VALUE_FIELD_NUMBER = 1; - private com.google.protobuf.ByteString value_ = com.google.protobuf.ByteString.EMPTY; + private com.google.protobuf.ByteString value_; /** * bytes value = 1; * @return The value. @@ -8568,13 +9327,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueS return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) @@ -8649,8 +9406,8 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - bitField0_ = 0; value_ = com.google.protobuf.ByteString.EMPTY; + return this; } @@ -8677,20 +9434,45 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpd @java.lang.Override public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate buildPartial() { org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate(this); - if (bitField0_ != 0) { buildPartial0(result); } + result.value_ = value_; onBuilt(); return result; } - private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result) { - int from_bitField0_ = bitField0_; - if (((from_bitField0_ & 0x00000001) != 0)) { - result.value_ = value_; - } + @java.lang.Override + public Builder clone() { + return super.clone(); } - @java.lang.Override - public Builder mergeFrom(com.google.protobuf.Message other) { + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public 
Builder mergeFrom(com.google.protobuf.Message other) { if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate) { return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate)other); } else { @@ -8732,7 +9514,7 @@ public Builder mergeFrom( break; case 10: { value_ = input.readBytes(); - bitField0_ |= 0x00000001; + break; } // case 10 default: { @@ -8750,7 +9532,6 @@ public Builder mergeFrom( } // finally return this; } - private int bitField0_; private com.google.protobuf.ByteString value_ = com.google.protobuf.ByteString.EMPTY; /** @@ -8767,9 +9548,11 @@ public com.google.protobuf.ByteString getValue() { * @return This builder for chaining. */ public Builder setValue(com.google.protobuf.ByteString value) { - if (value == null) { throw new NullPointerException(); } + if (value == null) { + throw new NullPointerException(); + } + value_ = value; - bitField0_ |= 0x00000001; onChanged(); return this; } @@ -8778,11 +9561,23 @@ public Builder setValue(com.google.protobuf.ByteString value) { * @return This builder for chaining. */ public Builder clearValue() { - bitField0_ = (bitField0_ & ~0x00000001); + value_ = getDefaultInstance().getValue(); onChanged(); return this; } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) } @@ -8854,6 +9649,18 @@ private Clear(com.google.protobuf.GeneratedMessageV3.Builder builder) { private Clear() { } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new Clear(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; @@ -8965,13 +9772,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) @@ -9076,6 +9881,38 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear buildPa return result; } + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + 
@java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) { @@ -9129,6 +9966,18 @@ public Builder mergeFrom( } // finally return this; } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Clear) } @@ -9212,6 +10061,18 @@ private SetHandleState() { state_ = 0; } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new SetHandleState(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_descriptor; @@ -9226,7 +10087,7 @@ private SetHandleState() { } public static final int STATE_FIELD_NUMBER = 1; - private int state_ = 0; + private int state_; /** * .org.apache.spark.sql.execution.streaming.state.HandleState state = 1; * @return The enum numeric value on the wire for state. @@ -9239,7 +10100,8 @@ private SetHandleState() { * @return The state. */ @java.lang.Override public org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState getState() { - org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState result = org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState.forNumber(state_); + @SuppressWarnings("deprecation") + org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState result = org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState.valueOf(state_); return result == null ? 
org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState.UNRECOGNIZED : result; } @@ -9351,13 +10213,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetHan return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) @@ -9432,8 +10292,8 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - bitField0_ = 0; state_ = 0; + return this; } @@ -9460,18 +10320,43 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStat @java.lang.Override public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState buildPartial() { org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState result = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState(this); - if (bitField0_ != 0) { buildPartial0(result); } + result.state_ = state_; onBuilt(); return result; } - private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState result) { - int from_bitField0_ = bitField0_; - if (((from_bitField0_ & 0x00000001) != 0)) { - result.state_ = state_; - } + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); } - @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState) { @@ -9515,7 +10400,7 @@ public Builder mergeFrom( break; case 8: { state_ = input.readEnum(); - bitField0_ |= 0x00000001; + break; } // case 8 default: { @@ -9533,7 +10418,6 @@ public Builder mergeFrom( } // finally return this; } - private int bitField0_; private int state_ = 0; /** @@ -9549,8 +10433,8 @@ public Builder mergeFrom( * @return This builder for chaining. 
*/ public Builder setStateValue(int value) { + state_ = value; - bitField0_ |= 0x00000001; onChanged(); return this; } @@ -9560,7 +10444,8 @@ public Builder setStateValue(int value) { */ @java.lang.Override public org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState getState() { - org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState result = org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState.forNumber(state_); + @SuppressWarnings("deprecation") + org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState result = org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState.valueOf(state_); return result == null ? org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState.UNRECOGNIZED : result; } /** @@ -9572,7 +10457,7 @@ public Builder setState(org.apache.spark.sql.execution.streaming.state.StateMess if (value == null) { throw new NullPointerException(); } - bitField0_ |= 0x00000001; + state_ = value.getNumber(); onChanged(); return this; @@ -9582,11 +10467,23 @@ public Builder setState(org.apache.spark.sql.execution.streaming.state.StateMess * @return This builder for chaining. */ public Builder clearState() { - bitField0_ = (bitField0_ & ~0x00000001); + state_ = 0; onChanged(); return this; } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.SetHandleState) } @@ -9639,6 +10536,476 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStat } + public interface TTLConfigOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.TTLConfig) + com.google.protobuf.MessageOrBuilder { + + /** + * int32 durationMs = 1; + * @return The durationMs. + */ + int getDurationMs(); + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.TTLConfig} + */ + public static final class TTLConfig extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.TTLConfig) + TTLConfigOrBuilder { + private static final long serialVersionUID = 0L; + // Use TTLConfig.newBuilder() to construct. 
+ private TTLConfig(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private TTLConfig() { + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new TTLConfig(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.class, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder.class); + } + + public static final int DURATIONMS_FIELD_NUMBER = 1; + private int durationMs_; + /** + * int32 durationMs = 1; + * @return The durationMs. + */ + @java.lang.Override + public int getDurationMs() { + return durationMs_; + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (durationMs_ != 0) { + output.writeInt32(1, durationMs_); + } + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (durationMs_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(1, durationMs_); + } + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig other = (org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig) obj; + + if (getDurationMs() + != other.getDurationMs()) return false; + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + DURATIONMS_FIELD_NUMBER; + hash = (53 * hash) + getDurationMs(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom( + java.nio.ByteBuffer data, + 
com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.TTLConfig} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.TTLConfig) + org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.class, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + durationMs_ = 0; + + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig result = new org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig(this); + result.durationMs_ = durationMs_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + 
int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.getDefaultInstance()) return this; + if (other.getDurationMs() != 0) { + setDurationMs(other.getDurationMs()); + } + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 8: { + durationMs_ = input.readInt32(); + + break; + } // case 8 + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + + private int durationMs_ ; + /** + * int32 durationMs = 1; + * @return The durationMs. + */ + @java.lang.Override + public int getDurationMs() { + return durationMs_; + } + /** + * int32 durationMs = 1; + * @param value The durationMs to set. + * @return This builder for chaining. + */ + public Builder setDurationMs(int value) { + + durationMs_ = value; + onChanged(); + return this; + } + /** + * int32 durationMs = 1; + * @return This builder for chaining. 
+ */ + public Builder clearDurationMs() { + + durationMs_ = 0; + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.TTLConfig) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.TTLConfig) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public TTLConfig parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + private static final com.google.protobuf.Descriptors.Descriptor internal_static_org_apache_spark_sql_execution_streaming_state_StateRequest_descriptor; private static final @@ -9709,6 +11076,11 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStat private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_fieldAccessorTable; public static com.google.protobuf.Descriptors.FileDescriptor getDescriptor() { @@ -9749,24 +11121,27 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStat "ming.state.SetImplicitKeyH\000\022^\n\021removeImp" + "licitKey\030\002 \001(\0132A.org.apache.spark.sql.ex" + "ecution.streaming.state.RemoveImplicitKe" + - "yH\000B\010\n\006method\"5\n\020StateCallCommand\022\021\n\tsta" + - "teName\030\001 \001(\t\022\016\n\006schema\030\002 
\001(\t\"\341\002\n\016ValueSt" + - "ateCall\022\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002 \001" + - "(\01326.org.apache.spark.sql.execution.stre" + - "aming.state.ExistsH\000\022B\n\003get\030\003 \001(\01323.org." + - "apache.spark.sql.execution.streaming.sta" + - "te.GetH\000\022\\\n\020valueStateUpdate\030\004 \001(\0132@.org" + + "yH\000B\010\n\006method\"}\n\020StateCallCommand\022\021\n\tsta" + + "teName\030\001 \001(\t\022\016\n\006schema\030\002 \001(\t\022F\n\003ttl\030\003 \001(" + + "\01329.org.apache.spark.sql.execution.strea" + + "ming.state.TTLConfig\"\341\002\n\016ValueStateCall\022" + + "\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002 \001(\01326.org" + ".apache.spark.sql.execution.streaming.st" + - "ate.ValueStateUpdateH\000\022F\n\005clear\030\005 \001(\01325." + - "org.apache.spark.sql.execution.streaming" + - ".state.ClearH\000B\010\n\006method\"\035\n\016SetImplicitK" + - "ey\022\013\n\003key\030\001 \001(\014\"\023\n\021RemoveImplicitKey\"\010\n\006" + - "Exists\"\005\n\003Get\"!\n\020ValueStateUpdate\022\r\n\005val" + - "ue\030\001 \001(\014\"\007\n\005Clear\"\\\n\016SetHandleState\022J\n\005s" + - "tate\030\001 \001(\0162;.org.apache.spark.sql.execut" + - "ion.streaming.state.HandleState*K\n\013Handl" + - "eState\022\013\n\007CREATED\020\000\022\017\n\013INITIALIZED\020\001\022\022\n\016" + - "DATA_PROCESSED\020\002\022\n\n\006CLOSED\020\003b\006proto3" + "ate.ExistsH\000\022B\n\003get\030\003 \001(\01323.org.apache.s" + + "park.sql.execution.streaming.state.GetH\000" + + "\022\\\n\020valueStateUpdate\030\004 \001(\0132@.org.apache." + + "spark.sql.execution.streaming.state.Valu" + + "eStateUpdateH\000\022F\n\005clear\030\005 \001(\01325.org.apac" + + "he.spark.sql.execution.streaming.state.C" + + "learH\000B\010\n\006method\"\035\n\016SetImplicitKey\022\013\n\003ke" + + "y\030\001 \001(\014\"\023\n\021RemoveImplicitKey\"\010\n\006Exists\"\005" + + "\n\003Get\"!\n\020ValueStateUpdate\022\r\n\005value\030\001 \001(\014" + + "\"\007\n\005Clear\"\\\n\016SetHandleState\022J\n\005state\030\001 \001" + + "(\0162;.org.apache.spark.sql.execution.stre" + + "aming.state.HandleState\"\037\n\tTTLConfig\022\022\n\n" + + "durationMs\030\001 \001(\005*K\n\013HandleState\022\013\n\007CREAT" + + "ED\020\000\022\017\n\013INITIALIZED\020\001\022\022\n\016DATA_PROCESSED\020" + + "\002\022\n\n\006CLOSED\020\003b\006proto3" }; descriptor = com.google.protobuf.Descriptors.FileDescriptor .internalBuildGeneratedFileFrom(descriptorData, @@ -9807,7 +11182,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStat internal_static_org_apache_spark_sql_execution_streaming_state_StateCallCommand_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_StateCallCommand_descriptor, - new java.lang.String[] { "StateName", "Schema", }); + new java.lang.String[] { "StateName", "Schema", "Ttl", }); internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateCall_descriptor = getDescriptor().getMessageTypes().get(6); internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateCall_fieldAccessorTable = new @@ -9856,6 +11231,12 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStat com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_descriptor, new java.lang.String[] { "State", }); + 
internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor = + getDescriptor().getMessageTypes().get(14); + internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor, + new java.lang.String[] { "DurationMs", }); } // @@protoc_insertion_point(outer_class_scope) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 8b2fc7f5db31e..10594d6c5d340 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -72,7 +72,7 @@ public void reset() { numNulls = 0; } - if (hugeVectorThreshold > 0 && capacity > hugeVectorThreshold) { + if (hugeVectorThreshold > -1 && capacity > hugeVectorThreshold) { capacity = defaultCapacity; releaseMemory(); reserveInternal(capacity); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 53640f513fc81..b356751083fc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -21,6 +21,7 @@ import java.{lang => jl} import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.ExpressionUtils.column @@ -33,7 +34,7 @@ import org.apache.spark.sql.types._ */ @Stable final class DataFrameNaFunctions private[sql](df: DataFrame) - extends api.DataFrameNaFunctions[Dataset] { + extends api.DataFrameNaFunctions { import df.sparkSession.RichColumn protected def drop(minNonNulls: Option[Int]): Dataset[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 9d7a765a24c92..78cc65bb7a298 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -24,14 +24,14 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.Partition import org.apache.spark.annotation.Stable import org.apache.spark.api.java.JavaRDD -import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, FailureSafeParser} +import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.catalyst.xml.{StaxXmlParser, XmlOptions} +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource @@ -54,136 
+54,44 @@ import org.apache.spark.unsafe.types.UTF8String * @since 1.4.0 */ @Stable -class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { +class DataFrameReader private[sql](sparkSession: SparkSession) + extends api.DataFrameReader { + override type DS[U] = Dataset[U] - /** - * Specifies the input data source format. - * - * @since 1.4.0 - */ - def format(source: String): DataFrameReader = { - this.source = source - this - } + format(sparkSession.sessionState.conf.defaultDataSourceName) - /** - * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema - * automatically from data. By specifying the schema here, the underlying data source can - * skip the schema inference step, and thus speed up data loading. - * - * @since 1.4.0 - */ - def schema(schema: StructType): DataFrameReader = { - if (schema != null) { - val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] - this.userSpecifiedSchema = Option(replaced) - validateSingleVariantColumn() - } - this - } + /** @inheritdoc */ + override def format(source: String): this.type = super.format(source) - /** - * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can - * infer the input schema automatically from data. By specifying the schema here, the underlying - * data source can skip the schema inference step, and thus speed up data loading. - * - * {{{ - * spark.read.schema("a INT, b STRING, c DOUBLE").csv("test.csv") - * }}} - * - * @since 2.3.0 - */ - def schema(schemaString: String): DataFrameReader = { - schema(StructType.fromDDL(schemaString)) - } + /** @inheritdoc */ + override def schema(schema: StructType): this.type = super.schema(schema) - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. - * If a new option has the same key case-insensitively, it will override the existing option. - * - * @since 1.4.0 - */ - def option(key: String, value: String): DataFrameReader = { - this.extraOptions = this.extraOptions + (key -> value) - validateSingleVariantColumn() - this - } + /** @inheritdoc */ + override def schema(schemaString: String): this.type = super.schema(schemaString) - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. - * If a new option has the same key case-insensitively, it will override the existing option. - * - * @since 2.0.0 - */ - def option(key: String, value: Boolean): DataFrameReader = option(key, value.toString) + /** @inheritdoc */ + override def option(key: String, value: String): this.type = super.option(key, value) - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. - * If a new option has the same key case-insensitively, it will override the existing option. - * - * @since 2.0.0 - */ - def option(key: String, value: Long): DataFrameReader = option(key, value.toString) + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. - * If a new option has the same key case-insensitively, it will override the existing option. 
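All of the overrides above follow the same shape: the shared builder logic now lives in the parent `api.DataFrameReader`, and this Classic implementation re-declares each overload with a `this.type` return and delegates to `super`, so chained calls keep returning the concrete reader. A stripped-down illustration of that pattern with stand-in classes (these are not the real Spark types):

```scala
// Stand-in classes sketching the `this.type` delegation pattern used in this refactor.
abstract class AbstractReader {
  protected var extraOptions: Map[String, String] = Map.empty

  def option(key: String, value: String): this.type = {
    extraOptions += (key -> value)
    this
  }
  def option(key: String, value: Boolean): this.type = option(key, value.toString)
}

class ClassicReader extends AbstractReader {
  // Re-declaring the overload on the concrete class keeps its public signature in place
  // while the behavior stays in the parent.
  override def option(key: String, value: String): this.type = super.option(key, value)
}

// Mixed overloads still compose because every call returns the receiver's own type.
val reader = new ClassicReader().option("header", true).option("sep", ",")
```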
- * - * @since 2.0.0 - */ - def option(key: String, value: Double): DataFrameReader = option(key, value.toString) + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) - /** - * (Scala-specific) Adds input options for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. - * If a new option has the same key case-insensitively, it will override the existing option. - * - * @since 1.4.0 - */ - def options(options: scala.collection.Map[String, String]): DataFrameReader = { - this.extraOptions ++= options - validateSingleVariantColumn() - this - } + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) - /** - * Adds input options for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. - * If a new option has the same key case-insensitively, it will override the existing option. - * - * @since 1.4.0 - */ - def options(options: java.util.Map[String, String]): DataFrameReader = { - this.options(options.asScala) - this - } + /** @inheritdoc */ + override def options(options: scala.collection.Map[String, String]): this.type = + super.options(options) - /** - * Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external - * key-value stores). - * - * @since 1.4.0 - */ - def load(): DataFrame = { - load(Seq.empty: _*) // force invocation of `load(...varargs...)` - } + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = super.options(options) - /** - * Loads input in as a `DataFrame`, for data sources that require a path (e.g. data backed by - * a local or distributed file system). - * - * @since 1.4.0 - */ + /** @inheritdoc */ + override def load(): DataFrame = load(Nil: _*) + + /** @inheritdoc */ def load(path: String): DataFrame = { // force invocation of `load(...varargs...)` if (sparkSession.sessionState.conf.legacyPathOptionBehavior) { @@ -193,12 +101,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } } - /** - * Loads input in as a `DataFrame`, for data sources that support multiple paths. - * Only works if the source is a HadoopFsRelationProvider. - * - * @since 1.6.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def load(paths: String*): DataFrame = { if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { @@ -235,90 +138,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { options = finalOptions.originalMap).resolveRelation()) } - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL - * url named table and connection properties. - * - * You can find the JDBC-specific option and parameter documentation for reading tables - * via JDBC in - * - * Data Source Option in the version you use. - * - * @since 1.4.0 - */ - def jdbc(url: String, table: String, properties: Properties): DataFrame = { - assertNoSpecifiedSchema("jdbc") - // properties should override settings in extraOptions. 
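The `jdbc` overloads removed here are replaced by thin `super.jdbc(...)` forwards with identical public signatures, so existing call sites compile unchanged. A hedged usage sketch of the partitioned variant; the URL, table, bounds, and credentials are placeholders and `spark` is assumed to be an active `SparkSession`:

```scala
import java.util.Properties

val props = new Properties()
props.setProperty("user", "spark")      // placeholder credentials
props.setProperty("password", "secret")

// Read the table in 8 parallel partitions, split on the numeric order_id column.
val orders = spark.read.jdbc(
  url = "jdbc:postgresql://db:5432/shop",
  table = "orders",
  columnName = "order_id",
  lowerBound = 0L,
  upperBound = 1000000L,
  numPartitions = 8,
  connectionProperties = props)
```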
- this.extraOptions ++= properties.asScala - // explicit url and dbtable should override all - this.extraOptions ++= Seq(JDBCOptions.JDBC_URL -> url, JDBCOptions.JDBC_TABLE_NAME -> table) - format("jdbc").load() - } + /** @inheritdoc */ + override def jdbc(url: String, table: String, properties: Properties): DataFrame = + super.jdbc(url, table, properties) - // scalastyle:off line.size.limit - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL - * url named table. Partitions of the table will be retrieved in parallel based on the parameters - * passed to this function. - * - * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash - * your external database systems. - * - * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC in - * - * Data Source Option in the version you use. - * - * @param table Name of the table in the external database. - * @param columnName Alias of `partitionColumn` option. Refer to `partitionColumn` in - * - * Data Source Option in the version you use. - * @param connectionProperties JDBC database connection arguments, a list of arbitrary string - * tag/value. Normally at least a "user" and "password" property - * should be included. "fetchsize" can be used to control the - * number of rows per fetch and "queryTimeout" can be used to wait - * for a Statement object to execute to the given number of seconds. - * @since 1.4.0 - */ - // scalastyle:on line.size.limit - def jdbc( + /** @inheritdoc */ + override def jdbc( url: String, table: String, columnName: String, lowerBound: Long, upperBound: Long, numPartitions: Int, - connectionProperties: Properties): DataFrame = { - // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions. - this.extraOptions ++= Map( - JDBCOptions.JDBC_PARTITION_COLUMN -> columnName, - JDBCOptions.JDBC_LOWER_BOUND -> lowerBound.toString, - JDBCOptions.JDBC_UPPER_BOUND -> upperBound.toString, - JDBCOptions.JDBC_NUM_PARTITIONS -> numPartitions.toString) - jdbc(url, table, connectionProperties) - } + connectionProperties: Properties): DataFrame = + super.jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, connectionProperties) - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL - * url named table using connection properties. The `predicates` parameter gives a list - * expressions suitable for inclusion in WHERE clauses; each one defines one partition - * of the `DataFrame`. - * - * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash - * your external database systems. - * - * You can find the JDBC-specific option and parameter documentation for reading tables - * via JDBC in - * - * Data Source Option in the version you use. - * - * @param table Name of the table in the external database. - * @param predicates Condition in the where clause for each partition. - * @param connectionProperties JDBC database connection arguments, a list of arbitrary string - * tag/value. Normally at least a "user" and "password" property - * should be included. "fetchsize" can be used to control the - * number of rows per fetch. - * @since 1.4.0 - */ + /** @inheritdoc */ def jdbc( url: String, table: String, @@ -335,38 +170,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { sparkSession.baseRelationToDataFrame(relation) } - /** - * Loads a JSON file and returns the results as a `DataFrame`. 
- * - * See the documentation on the overloaded `json()` method with varargs for more details. - * - * @since 1.4.0 - */ - def json(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - json(Seq(path): _*) - } + /** @inheritdoc */ + override def json(path: String): DataFrame = super.json(path) - /** - * Loads JSON files and returns the results as a `DataFrame`. - * - * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `multiLine` option to true. - * - * This function goes through the input once to determine the input schema. If you know the - * schema in advance, use the version that specifies the schema to avoid the extra scan. - * - * You can find the JSON-specific options for reading JSON files in - * - * Data Source Option in the version you use. - * - * @since 2.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def json(paths: String*): DataFrame = { - userSpecifiedSchema.foreach(checkJsonSchema) - format("json").load(paths : _*) - } + override def json(paths: String*): DataFrame = super.json(paths: _*) /** * Loads a `JavaRDD[String]` storing JSON objects (JSON @@ -397,16 +206,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { json(sparkSession.createDataset(jsonRDD)(Encoders.STRING)) } - /** - * Loads a `Dataset[String]` storing JSON objects (JSON Lines - * text format or newline-delimited JSON) and returns the result as a `DataFrame`. - * - * Unless the schema is specified using `schema` function, this function goes through the - * input once to determine the input schema. - * - * @param jsonDataset input Dataset with one JSON object per record - * @since 2.2.0 - */ + /** @inheritdoc */ def json(jsonDataset: Dataset[String]): DataFrame = { val parsedOptions = new JSONOptions( extraOptions.toMap, @@ -439,36 +239,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = jsonDataset.isStreaming) } - /** - * Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the - * other overloaded `csv()` method for more details. - * - * @since 2.0.0 - */ - def csv(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - csv(Seq(path): _*) - } + /** @inheritdoc */ + override def csv(path: String): DataFrame = super.csv(path) - /** - * Loads an `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`. - * - * If the schema is not specified using `schema` function and `inferSchema` option is enabled, - * this function goes through the input once to determine the input schema. - * - * If the schema is not specified using `schema` function and `inferSchema` option is disabled, - * it determines the columns as string types and it reads only the first line to determine the - * names and the number of fields. - * - * If the enforceSchema is set to `false`, only the CSV header in the first line is checked - * to conform specified or inferred schema. - * - * @note if `header` option is set to `true` when calling this API, all lines same with - * the header will be removed if exists. 
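The `Dataset[String]` overloads of `json` and `csv` stay as concrete implementations in this file because they depend on Classic execution internals. A small usage sketch with toy in-memory data, assuming an active `spark` session and `import spark.implicits._`:

```scala
import spark.implicits._

// JSON Lines held in memory rather than on disk.
val jsonLines = Seq("""{"id": 1, "name": "a"}""", """{"id": 2, "name": "b"}""").toDS()
val fromJson = spark.read.json(jsonLines)

// CSV rows, with the first element treated as the header.
val csvRows = Seq("id,name", "1,a", "2,b").toDS()
val fromCsv = spark.read.option("header", "true").csv(csvRows)
```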
- * - * @param csvDataset input Dataset with one CSV row per record - * @since 2.2.0 - */ + /** @inheritdoc */ def csv(csvDataset: Dataset[String]): DataFrame = { val parsedOptions: CSVOptions = new CSVOptions( extraOptions.toMap, @@ -527,61 +301,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = csvDataset.isStreaming) } - /** - * Loads CSV files and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can find the CSV-specific options for reading CSV files in - * - * Data Source Option in the version you use. - * - * @since 2.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def csv(paths: String*): DataFrame = format("csv").load(paths : _*) + override def csv(paths: String*): DataFrame = super.csv(paths: _*) - /** - * Loads a XML file and returns the result as a `DataFrame`. See the documentation on the - * other overloaded `xml()` method for more details. - * - * @since 4.0.0 - */ - def xml(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - xml(Seq(path): _*) - } + /** @inheritdoc */ + override def xml(path: String): DataFrame = super.xml(path) - /** - * Loads XML files and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can find the XML-specific options for reading XML files in - * - * Data Source Option in the version you use. - * - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def xml(paths: String*): DataFrame = { - userSpecifiedSchema.foreach(checkXmlSchema) - format("xml").load(paths: _*) - } + override def xml(paths: String*): DataFrame = super.xml(paths: _*) - /** - * Loads an `Dataset[String]` storing XML object and returns the result as a `DataFrame`. - * - * If the schema is not specified using `schema` function and `inferSchema` option is enabled, - * this function goes through the input once to determine the input schema. - * - * @param xmlDataset input Dataset with one XML object per record - * @since 4.0.0 - */ + /** @inheritdoc */ def xml(xmlDataset: Dataset[String]): DataFrame = { val parsedOptions: XmlOptions = new XmlOptions( extraOptions.toMap, @@ -614,70 +345,21 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = xmlDataset.isStreaming) } - /** - * Loads a Parquet file, returning the result as a `DataFrame`. See the documentation - * on the other overloaded `parquet()` method for more details. - * - * @since 2.0.0 - */ - def parquet(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - parquet(Seq(path): _*) - } + /** @inheritdoc */ + override def parquet(path: String): DataFrame = super.parquet(path) - /** - * Loads a Parquet file, returning the result as a `DataFrame`. - * - * Parquet-specific option(s) for reading Parquet files can be found in - * - * Data Source Option in the version you use. 
- * - * @since 1.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def parquet(paths: String*): DataFrame = { - format("parquet").load(paths: _*) - } + override def parquet(paths: String*): DataFrame = super.parquet(paths: _*) - /** - * Loads an ORC file and returns the result as a `DataFrame`. - * - * @param path input path - * @since 1.5.0 - */ - def orc(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - orc(Seq(path): _*) - } + /** @inheritdoc */ + override def orc(path: String): DataFrame = super.orc(path) - /** - * Loads ORC files and returns the result as a `DataFrame`. - * - * ORC-specific option(s) for reading ORC files can be found in - * - * Data Source Option in the version you use. - * - * @param paths input paths - * @since 2.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def orc(paths: String*): DataFrame = format("orc").load(paths: _*) + override def orc(paths: String*): DataFrame = super.orc(paths: _*) - /** - * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch - * reading and the returned DataFrame is the batch scan query plan of this table. If it's a view, - * the returned DataFrame is simply the query plan of the view, which can either be a batch or - * streaming query plan. - * - * @param tableName is either a qualified or unqualified name that designates a table or view. - * If a database is specified, it identifies the table/view from the database. - * Otherwise, it first attempts to find a temporary view with the given name - * and then match the table/view from the current database. - * Note that, the global temporary view database is also valid here. - * @since 1.4.0 - */ + /** @inheritdoc */ def table(tableName: String): DataFrame = { assertNoSpecifiedSchema("table") val multipartIdentifier = @@ -686,108 +368,31 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { new CaseInsensitiveStringMap(extraOptions.toMap.asJava))) } - /** - * Loads text files and returns a `DataFrame` whose schema starts with a string column named - * "value", and followed by partitioned columns if there are any. See the documentation on - * the other overloaded `text()` method for more details. - * - * @since 2.0.0 - */ - def text(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - text(Seq(path): _*) - } + /** @inheritdoc */ + override def text(path: String): DataFrame = super.text(path) - /** - * Loads text files and returns a `DataFrame` whose schema starts with a string column named - * "value", and followed by partitioned columns if there are any. - * The text files must be encoded as UTF-8. - * - * By default, each line in the text files is a new row in the resulting DataFrame. For example: - * {{{ - * // Scala: - * spark.read.text("/path/to/spark/README.md") - * - * // Java: - * spark.read().text("/path/to/spark/README.md") - * }}} - * - * You can find the text-specific options for reading text files in - * - * Data Source Option in the version you use. - * - * @param paths input paths - * @since 1.6.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def text(paths: String*): DataFrame = format("text").load(paths : _*) + override def text(paths: String*): DataFrame = super.text(paths: _*) - /** - * Loads text files and returns a [[Dataset]] of String. See the documentation on the - * other overloaded `textFile()` method for more details. 
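`text` and `textFile` (continued below) differ only in what they return: a `DataFrame` with a `value` column plus any partition columns, versus a plain `Dataset[String]`. A quick sketch with a placeholder path, assuming an active `spark` session:

```scala
import org.apache.spark.sql.{DataFrame, Dataset}

// Both calls read UTF-8 text files line by line; the path is a placeholder.
val asRows: DataFrame = spark.read.text("/path/to/logs")          // schema: value STRING (+ partitions)
val asStrings: Dataset[String] = spark.read.textFile("/path/to/logs")
```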
- * @since 2.0.0 - */ - def textFile(path: String): Dataset[String] = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - textFile(Seq(path): _*) - } + /** @inheritdoc */ + override def textFile(path: String): Dataset[String] = super.textFile(path) - /** - * Loads text files and returns a [[Dataset]] of String. The underlying schema of the Dataset - * contains a single string column named "value". - * The text files must be encoded as UTF-8. - * - * If the directory structure of the text files contains partitioning information, those are - * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. - * - * By default, each line in the text files is a new row in the resulting DataFrame. For example: - * {{{ - * // Scala: - * spark.read.textFile("/path/to/spark/README.md") - * - * // Java: - * spark.read().textFile("/path/to/spark/README.md") - * }}} - * - * You can set the text-specific options as specified in `DataFrameReader.text`. - * - * @param paths input path - * @since 2.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def textFile(paths: String*): Dataset[String] = { - assertNoSpecifiedSchema("textFile") - text(paths : _*).select("value").as[String](sparkSession.implicits.newStringEncoder) - } + override def textFile(paths: String*): Dataset[String] = super.textFile(paths: _*) - /** - * A convenient function for schema validation in APIs. - */ - private def assertNoSpecifiedSchema(operation: String): Unit = { - if (userSpecifiedSchema.nonEmpty) { - throw QueryCompilationErrors.userSpecifiedSchemaUnsupportedError(operation) - } - } - - /** - * Ensure that the `singleVariantColumn` option cannot be used if there is also a user specified - * schema. - */ - private def validateSingleVariantColumn(): Unit = { + /** @inheritdoc */ + override protected def validateSingleVariantColumn(): Unit = { if (extraOptions.get(JSONOptions.SINGLE_VARIANT_COLUMN).isDefined && userSpecifiedSchema.isDefined) { throw QueryCompilationErrors.invalidSingleVariantColumn() } } - /////////////////////////////////////////////////////////////////////////////////////// - // Builder pattern config options - /////////////////////////////////////////////////////////////////////////////////////// - - private var source: String = sparkSession.sessionState.conf.defaultDataSourceName - - private var userSpecifiedSchema: Option[StructType] = None - - private var extraOptions = CaseInsensitiveMap[String](Map.empty) + override protected def validateJsonSchema(): Unit = + userSpecifiedSchema.foreach(checkJsonSchema) + override protected def validateXmlSchema(): Unit = + userSpecifiedSchema.foreach(checkXmlSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index a5ab237bb7041..9f7180d8dfd6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -23,6 +23,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.execution.stat._ import org.apache.spark.sql.functions.col import org.apache.spark.util.ArrayImplicits._ @@ -34,7 +35,7 @@ import org.apache.spark.util.ArrayImplicits._ */ @Stable final class DataFrameStatFunctions 
private[sql](protected val df: DataFrame) - extends api.DataFrameStatFunctions[Dataset] { + extends api.DataFrameStatFunctions { /** @inheritdoc */ def approxQuantile( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala deleted file mode 100644 index 576d8276b56ef..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ /dev/null @@ -1,404 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.collection.mutable -import scala.jdk.CollectionConverters._ - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedFunction, UnresolvedIdentifier, UnresolvedRelation} -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OptionList, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, UnresolvedTableSpec} -import org.apache.spark.sql.connector.catalog.TableWritePrivilege._ -import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, LogicalExpressions, NamedReference, Transform} -import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.types.IntegerType - -/** - * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 API. 
- * - * @since 3.0.0 - */ -@Experimental -final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) - extends CreateTableWriter[T] { - - private val df: DataFrame = ds.toDF() - - private val sparkSession = ds.sparkSession - import sparkSession.expression - - private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table) - - private val logicalPlan = df.queryExecution.logical - - private var provider: Option[String] = None - - private val options = new mutable.HashMap[String, String]() - - private val properties = new mutable.HashMap[String, String]() - - private var partitioning: Option[Seq[Transform]] = None - - private var clustering: Option[ClusterByTransform] = None - - override def using(provider: String): CreateTableWriter[T] = { - this.provider = Some(provider) - this - } - - override def option(key: String, value: String): DataFrameWriterV2[T] = { - this.options.put(key, value) - this - } - - override def options(options: scala.collection.Map[String, String]): DataFrameWriterV2[T] = { - options.foreach { - case (key, value) => - this.options.put(key, value) - } - this - } - - override def options(options: java.util.Map[String, String]): DataFrameWriterV2[T] = { - this.options(options.asScala) - this - } - - override def tableProperty(property: String, value: String): CreateTableWriter[T] = { - this.properties.put(property, value) - this - } - - @scala.annotation.varargs - override def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] = { - def ref(name: String): NamedReference = LogicalExpressions.parseReference(name) - - val asTransforms = (column +: columns).map(expression).map { - case PartitionTransform.YEARS(Seq(attr: Attribute)) => - LogicalExpressions.years(ref(attr.name)) - case PartitionTransform.MONTHS(Seq(attr: Attribute)) => - LogicalExpressions.months(ref(attr.name)) - case PartitionTransform.DAYS(Seq(attr: Attribute)) => - LogicalExpressions.days(ref(attr.name)) - case PartitionTransform.HOURS(Seq(attr: Attribute)) => - LogicalExpressions.hours(ref(attr.name)) - case PartitionTransform.BUCKET(Seq(Literal(numBuckets: Int, IntegerType), attr: Attribute)) => - LogicalExpressions.bucket(numBuckets, Array(ref(attr.name))) - case PartitionTransform.BUCKET(Seq(numBuckets, e)) => - throw QueryCompilationErrors.invalidBucketsNumberError(numBuckets.toString, e.toString) - case attr: Attribute => - LogicalExpressions.identity(ref(attr.name)) - case expr => - throw QueryCompilationErrors.invalidPartitionTransformationError(expr) - } - - this.partitioning = Some(asTransforms) - validatePartitioning() - this - } - - @scala.annotation.varargs - override def clusterBy(colName: String, colNames: String*): CreateTableWriter[T] = { - this.clustering = - Some(ClusterByTransform((colName +: colNames).map(col => FieldReference(col)))) - validatePartitioning() - this - } - - /** - * Validate that clusterBy is not used with partitionBy. 
- */ - private def validatePartitioning(): Unit = { - if (partitioning.nonEmpty && clustering.nonEmpty) { - throw QueryCompilationErrors.clusterByWithPartitionedBy() - } - } - - override def create(): Unit = { - val tableSpec = UnresolvedTableSpec( - properties = properties.toMap, - provider = provider, - optionExpression = OptionList(Seq.empty), - location = None, - comment = None, - serde = None, - external = false) - runCommand( - CreateTableAsSelect( - UnresolvedIdentifier(tableName), - partitioning.getOrElse(Seq.empty) ++ clustering, - logicalPlan, - tableSpec, - options.toMap, - false)) - } - - override def replace(): Unit = { - internalReplace(orCreate = false) - } - - override def createOrReplace(): Unit = { - internalReplace(orCreate = true) - } - - - /** - * Append the contents of the data frame to the output table. - * - * If the output table does not exist, this operation will fail with - * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be - * validated to ensure it is compatible with the existing table. - * - * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist - */ - @throws(classOf[NoSuchTableException]) - def append(): Unit = { - val append = AppendData.byName( - UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)), - logicalPlan, options.toMap) - runCommand(append) - } - - /** - * Overwrite rows matching the given filter condition with the contents of the data frame in - * the output table. - * - * If the output table does not exist, this operation will fail with - * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. - * The data frame will be validated to ensure it is compatible with the existing table. - * - * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist - */ - @throws(classOf[NoSuchTableException]) - def overwrite(condition: Column): Unit = { - val overwrite = OverwriteByExpression.byName( - UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)), - logicalPlan, expression(condition), options.toMap) - runCommand(overwrite) - } - - /** - * Overwrite all partition for which the data frame contains at least one row with the contents - * of the data frame in the output table. - * - * This operation is equivalent to Hive's `INSERT OVERWRITE ... PARTITION`, which replaces - * partitions dynamically depending on the contents of the data frame. - * - * If the output table does not exist, this operation will fail with - * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be - * validated to ensure it is compatible with the existing table. - * - * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist - */ - @throws(classOf[NoSuchTableException]) - def overwritePartitions(): Unit = { - val dynamicOverwrite = OverwritePartitionsDynamic.byName( - UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)), - logicalPlan, options.toMap) - runCommand(dynamicOverwrite) - } - - /** - * Wrap an action to track the QueryExecution and time cost, then report to the user-registered - * callback functions. 
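Although this file is deleted, the `DataFrameWriterV2` surface itself survives; the Classic implementation moves to `DataFrameWriterV2Impl`, as the `Dataset.writeTo` change further down shows. A hedged usage sketch of the API this file documented, with placeholder catalog, table, and column names and `df` standing for any existing `DataFrame`:

```scala
import org.apache.spark.sql.functions.col

// Create or replace a v2 table partitioned by day, then append more rows later.
df.writeTo("catalog.db.events")
  .using("parquet")
  .partitionedBy(col("day"))
  .createOrReplace()

df.writeTo("catalog.db.events").append()
```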
- */ - private def runCommand(command: LogicalPlan): Unit = { - val qe = new QueryExecution(sparkSession, command, df.queryExecution.tracker) - qe.assertCommandExecuted() - } - - private def internalReplace(orCreate: Boolean): Unit = { - val tableSpec = UnresolvedTableSpec( - properties = properties.toMap, - provider = provider, - optionExpression = OptionList(Seq.empty), - location = None, - comment = None, - serde = None, - external = false) - runCommand(ReplaceTableAsSelect( - UnresolvedIdentifier(tableName), - partitioning.getOrElse(Seq.empty) ++ clustering, - logicalPlan, - tableSpec, - writeOptions = options.toMap, - orCreate = orCreate)) - } -} - -private object PartitionTransform { - class ExtractTransform(name: String) { - private val NAMES = Seq(name) - - def unapply(e: Expression): Option[Seq[Expression]] = e match { - case UnresolvedFunction(NAMES, children, false, None, false, Nil, true) => Option(children) - case _ => None - } - } - - val HOURS = new ExtractTransform("hours") - val DAYS = new ExtractTransform("days") - val MONTHS = new ExtractTransform("months") - val YEARS = new ExtractTransform("years") - val BUCKET = new ExtractTransform("bucket") -} - -/** - * Configuration methods common to create/replace operations and insert/overwrite operations. - * @tparam R builder type to return - * @since 3.0.0 - */ -trait WriteConfigMethods[R] { - /** - * Add a write option. - * - * @since 3.0.0 - */ - def option(key: String, value: String): R - - /** - * Add a boolean output option. - * - * @since 3.0.0 - */ - def option(key: String, value: Boolean): R = option(key, value.toString) - - /** - * Add a long output option. - * - * @since 3.0.0 - */ - def option(key: String, value: Long): R = option(key, value.toString) - - /** - * Add a double output option. - * - * @since 3.0.0 - */ - def option(key: String, value: Double): R = option(key, value.toString) - - /** - * Add write options from a Scala Map. - * - * @since 3.0.0 - */ - def options(options: scala.collection.Map[String, String]): R - - /** - * Add write options from a Java Map. - * - * @since 3.0.0 - */ - def options(options: java.util.Map[String, String]): R -} - -/** - * Trait to restrict calls to create and replace operations. - * - * @since 3.0.0 - */ -trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { - /** - * Create a new table from the contents of the data frame. - * - * The new table's schema, partition layout, properties, and other configuration will be - * based on the configuration set on this writer. - * - * If the output table exists, this operation will fail with - * [[org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException]]. - * - * @throws org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException - * If the table already exists - */ - @throws(classOf[TableAlreadyExistsException]) - def create(): Unit - - /** - * Replace an existing table with the contents of the data frame. - * - * The existing table's schema, partition layout, properties, and other configuration will be - * replaced with the contents of the data frame and the configuration set on this writer. - * - * If the output table does not exist, this operation will fail with - * [[org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException]]. 
- * - * @throws org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException - * If the table does not exist - */ - @throws(classOf[CannotReplaceMissingTableException]) - def replace(): Unit - - /** - * Create a new table or replace an existing table with the contents of the data frame. - * - * The output table's schema, partition layout, properties, and other configuration will be based - * on the contents of the data frame and the configuration set on this writer. If the table - * exists, its configuration and data will be replaced. - */ - def createOrReplace(): Unit - - /** - * Partition the output table created by `create`, `createOrReplace`, or `replace` using - * the given columns or transforms. - * - * When specified, the table data will be stored by these values for efficient reads. - * - * For example, when a table is partitioned by day, it may be stored in a directory layout like: - *

- * <ul>
- * <li>`table/day=2019-06-01/`</li>
- * <li>`table/day=2019-06-02/`</li>
- * </ul>
    - * - * Partitioning is one of the most widely used techniques to optimize physical data layout. - * It provides a coarse-grained index for skipping unnecessary data reads when queries have - * predicates on the partitioned columns. In order for partitioning to work well, the number - * of distinct values in each column should typically be less than tens of thousands. - * - * @since 3.0.0 - */ - @scala.annotation.varargs - def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] - - /** - * Clusters the output by the given columns on the storage. The rows with matching values in - * the specified clustering columns will be consolidated within the same group. - * - * For instance, if you cluster a dataset by date, the data sharing the same date will be stored - * together in a file. This arrangement improves query efficiency when you apply selective - * filters to these clustering columns, thanks to data skipping. - * - * @since 4.0.0 - */ - @scala.annotation.varargs - def clusterBy(colName: String, colNames: String*): CreateTableWriter[T] - - /** - * Specifies a provider for the underlying output data source. Spark's default catalog supports - * "parquet", "json", etc. - * - * @since 3.0.0 - */ - def using(provider: String): CreateTableWriter[T] - - /** - * Add a table property. - */ - def tableProperty(property: String, value: String): CreateTableWriter[T] -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 38521e8e16f91..ef628ca612b49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, Query import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder, StructEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} @@ -51,6 +52,7 @@ import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression @@ -60,7 +62,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable} import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions -import org.apache.spark.sql.internal.{DataFrameWriterImpl, SQLConf} +import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf} import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.TypedAggUtils.withInputType import org.apache.spark.sql.streaming.DataStreamWriter @@ -78,13 +80,14 @@ private[sql] object Dataset { 
val DATASET_ID_TAG = TreeNodeTag[HashSet[Long]]("dataset_id") def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = { - val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) + val encoder = implicitly[Encoder[T]] + val dataset = new Dataset(sparkSession, logicalPlan, encoder) // Eagerly bind the encoder so we verify that the encoder matches the underlying // schema. The user will get an error if this is not the case. // optimization: it is guaranteed that [[InternalRow]] can be converted to [[Row]] so // do not do this check in that case. this check can be expensive since it requires running // the whole [[Analyzer]] to resolve the deserializer - if (dataset.exprEnc.clsTag.runtimeClass != classOf[Row]) { + if (dataset.encoder.clsTag.runtimeClass != classOf[Row]) { dataset.resolvedEnc } dataset @@ -94,7 +97,7 @@ private[sql] object Dataset { sparkSession.withActive { val qe = sparkSession.sessionState.executePlan(logicalPlan) qe.assertAnalyzed() - new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema)) + new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema)) } def ofRows( @@ -105,7 +108,7 @@ private[sql] object Dataset { val qe = new QueryExecution( sparkSession, logicalPlan, shuffleCleanupMode = shuffleCleanupMode) qe.assertAnalyzed() - new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema)) + new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema)) } /** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */ @@ -118,7 +121,7 @@ private[sql] object Dataset { val qe = new QueryExecution( sparkSession, logicalPlan, tracker, shuffleCleanupMode = shuffleCleanupMode) qe.assertAnalyzed() - new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema)) + new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema)) } } @@ -213,7 +216,8 @@ private[sql] object Dataset { class Dataset[T] private[sql]( @DeveloperApi @Unstable @transient val queryExecution: QueryExecution, @DeveloperApi @Unstable @transient val encoder: Encoder[T]) - extends api.Dataset[T, Dataset] { + extends api.Dataset[T] { + type DS[U] = Dataset[U] type RGD = RelationalGroupedDataset @transient lazy val sparkSession: SparkSession = { @@ -243,7 +247,7 @@ class Dataset[T] private[sql]( @transient private[sql] val logicalPlan: LogicalPlan = { val plan = queryExecution.commandExecuted - if (sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) { + if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) { val dsIds = plan.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(new HashSet[Long]) dsIds.add(id) plan.setTagValue(Dataset.DATASET_ID_TAG, dsIds) @@ -252,12 +256,17 @@ class Dataset[T] private[sql]( } /** - * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the - * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use - * it when constructing new Dataset objects that have the same object type (that will be - * possibly resolved to a different schema). + * Expose the encoder as implicit so it can be used to construct new Dataset objects that have + * the same external type. */ - private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder) + private implicit def encoderImpl: Encoder[T] = encoder + + /** + * The actual [[ExpressionEncoder]] used by the dataset. 
This and its resolved counterpart should + * only be used for actual (de)serialization, the binding of Aggregator inputs, and in the rare + * cases where a plan needs to be constructed with an ExpressionEncoder. + */ + private[sql] lazy val exprEnc: ExpressionEncoder[T] = encoderFor(encoder) // The resolved `ExpressionEncoder` which can be used to turn rows to objects of type T, after // collecting rows to the driver side. @@ -265,7 +274,7 @@ class Dataset[T] private[sql]( exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer) } - private implicit def classTag: ClassTag[T] = exprEnc.clsTag + private implicit def classTag: ClassTag[T] = encoder.clsTag // sqlContext must be val because a stable identifier is expected when you import implicits @transient lazy val sqlContext: SQLContext = sparkSession.sqlContext @@ -476,7 +485,7 @@ class Dataset[T] private[sql]( /** @inheritdoc */ // This is declared with parentheses to prevent the Scala compiler from treating // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DataFrame = new Dataset[Row](queryExecution, ExpressionEncoder(schema)) + def toDF(): DataFrame = new Dataset[Row](queryExecution, RowEncoder.encoderFor(schema)) /** @inheritdoc */ def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) @@ -671,17 +680,17 @@ class Dataset[T] private[sql]( Some(condition.expr), JoinHint.NONE)).analyzed.asInstanceOf[Join] - implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder - .tuple(Seq(this.exprEnc, other.exprEnc), useNullSafeDeserializer = true) - .asInstanceOf[Encoder[(T, U)]] - - withTypedPlan(JoinWith.typedJoinWith( + val leftEncoder = agnosticEncoderFor(encoder) + val rightEncoder = agnosticEncoderFor(other.encoder) + val joinEncoder = ProductEncoder.tuple(Seq(leftEncoder, rightEncoder), elementsCanBeNull = true) + .asInstanceOf[Encoder[(T, U)]] + val joinWith = JoinWith.typedJoinWith( joined, sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity, sparkSession.sessionState.analyzer.resolver, - this.exprEnc.isSerializedAsStructForTopLevel, - other.exprEnc.isSerializedAsStructForTopLevel)) + leftEncoder.isStruct, + rightEncoder.isStruct) + new Dataset(sparkSession, joinWith, joinEncoder) } // TODO(SPARK-22947): Fix the DataFrame API. @@ -772,7 +781,7 @@ class Dataset[T] private[sql]( private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = { val newExpr = expr transform { case a: AttributeReference - if sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) => + if sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) => val metadata = new MetadataBuilder() .withMetadata(a.metadata) .putLong(Dataset.DATASET_ID_KEY, id) @@ -826,24 +835,29 @@ class Dataset[T] private[sql]( /** @inheritdoc */ def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { - implicit val encoder: ExpressionEncoder[U1] = encoderFor(c1.encoder) + val encoder = agnosticEncoderFor(c1.encoder) val tc1 = withInputType(c1.named, exprEnc, logicalPlan.output) val project = Project(tc1 :: Nil, logicalPlan) - if (!encoder.isSerializedAsStructForTopLevel) { - new Dataset[U1](sparkSession, project, encoder) - } else { - // Flattens inner fields of U1 - new Dataset[Tuple1[U1]](sparkSession, project, ExpressionEncoder.tuple(encoder)).map(_._1) + val plan = encoder match { + case se: StructEncoder[U1] => + // Flatten the result. 
+ val attribute = GetColumnByOrdinal(0, se.dataType) + val projectList = se.fields.zipWithIndex.map { + case (field, index) => + Alias(GetStructField(attribute, index, None), field.name)() + } + Project(projectList, project) + case _ => project } + new Dataset[U1](sparkSession, plan, encoder) } /** @inheritdoc */ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val encoders = columns.map(c => encoderFor(c.encoder)) + val encoders = columns.map(c => agnosticEncoderFor(c.encoder)) val namedColumns = columns.map(c => withInputType(c.named, exprEnc, logicalPlan.output)) - val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) - new Dataset(execution, ExpressionEncoder.tuple(encoders)) + new Dataset(sparkSession, Project(namedColumns, logicalPlan), ProductEncoder.tuple(encoders)) } /** @inheritdoc */ @@ -851,24 +865,7 @@ class Dataset[T] private[sql]( Filter(condition.expr, logicalPlan) } - /** - * Groups the Dataset using the specified columns, so we can run aggregation on them. See - * [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns grouped by department. - * ds.groupBy($"department").avg() - * - * // Compute the max age and average salary, grouped by department and gender. - * ds.groupBy($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 2.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def groupBy(cols: Column*): RelationalGroupedDataset = { RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType) @@ -900,35 +897,19 @@ class Dataset[T] private[sql]( rdd.reduce(func) } - /** - * (Scala-specific) - * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. - * - * @group typedrel - * @since 2.0.0 - */ + /** @inheritdoc */ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val withGroupingKey = AppendColumns(func, logicalPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) new KeyValueGroupedDataset( - encoderFor[K], - encoderFor[T], + implicitly[Encoder[K]], + encoder, executed, logicalPlan.output, withGroupingKey.newColumns) } - /** - * (Java-specific) - * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. 
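The Scala `groupByKey` now hands the user-facing `Encoder` to `KeyValueGroupedDataset` directly instead of eagerly converting it to an `ExpressionEncoder`; the API behaves the same. A toy usage sketch, with an illustrative case class and `import spark.implicits._` assumed:

```scala
import spark.implicits._

case class Sale(store: String, amount: Double)
val sales = Seq(Sale("a", 10.0), Sale("b", 5.0), Sale("a", 2.5)).toDS()

// Typed grouping; both results are Datasets, not DataFrames.
val counts = sales.groupByKey(_.store).count()                  // Dataset[(String, Long)]
val totals = sales.groupByKey(_.store)
  .mapGroups((store, rows) => (store, rows.map(_.amount).sum))  // Dataset[(String, Double)]
```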
- * - * @group typedrel - * @since 2.0.0 - */ - def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = - groupByKey(func.call(_))(encoder) - /** @inheritdoc */ def unpivot( ids: Array[Column], @@ -981,6 +962,22 @@ class Dataset[T] private[sql]( valueColumnName: String): DataFrame = unpivot(ids.toArray, variableColumnName, valueColumnName) + /** @inheritdoc */ + def transpose(indexColumn: Column): DataFrame = withPlan { + UnresolvedTranspose( + Seq(indexColumn.named), + logicalPlan + ) + } + + /** @inheritdoc */ + def transpose(): DataFrame = withPlan { + UnresolvedTranspose( + Seq.empty, + logicalPlan + ) + } + /** @inheritdoc */ @scala.annotation.varargs def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan { @@ -1362,12 +1359,6 @@ class Dataset[T] private[sql]( implicitly[Encoder[U]]) } - /** @inheritdoc */ - def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala - mapPartitions(func)(encoder) - } - /** * Returns a new `DataFrame` that contains the result of applying a serialized R function * `func` to each partition. @@ -1377,7 +1368,11 @@ class Dataset[T] private[sql]( packageNames: Array[Byte], broadcastVars: Array[Broadcast[Object]], schema: StructType): DataFrame = { - val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]] + val rowEncoder: ExpressionEncoder[Row] = if (isUnTyped) { + exprEnc.asInstanceOf[ExpressionEncoder[Row]] + } else { + ExpressionEncoder(schema) + } Dataset.ofRows( sparkSession, MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan)) @@ -1426,11 +1421,6 @@ class Dataset[T] private[sql]( Option(profile))) } - /** @inheritdoc */ - def foreach(f: T => Unit): Unit = withNewRDDExecutionId("foreach") { - rdd.foreach(f) - } - /** @inheritdoc */ def foreachPartition(f: Iterator[T] => Unit): Unit = withNewRDDExecutionId("foreachPartition") { rdd.foreachPartition(f) @@ -1606,25 +1596,7 @@ class Dataset[T] private[sql]( new DataFrameWriterImpl[T](this) } - /** - * Create a write configuration builder for v2 sources. - * - * This builder is used to configure and execute write operations. For example, to append to an - * existing table, run: - * - * {{{ - * df.writeTo("catalog.db.table").append() - * }}} - * - * This can also be used to create or replace existing tables: - * - * {{{ - * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace() - * }}} - * - * @group basic - * @since 3.0.0 - */ + /** @inheritdoc */ def writeTo(table: String): DataFrameWriterV2[T] = { // TODO: streaming could be adapted to use this interface if (isStreaming) { @@ -1632,31 +1604,10 @@ class Dataset[T] private[sql]( errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", messageParameters = Map("methodName" -> toSQLId("writeTo"))) } - new DataFrameWriterV2[T](table, this) + new DataFrameWriterV2Impl[T](table, this) } - /** - * Merges a set of updates, insertions, and deletions based on a source table into - * a target table. 
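The two `transpose` overloads added above both resolve through `UnresolvedTranspose`; a hedged usage sketch follows (the frame and column names are invented, and `import spark.implicits._` is assumed for `$`):
{{{
  // Hypothetical DataFrame for illustration only.
  val df = Seq(("x", 1, 2), ("y", 3, 4)).toDF("id", "a", "b")

  // Use an explicit index column...
  df.transpose($"id").show()

  // ...or the zero-argument variant, which passes no index column and lets the
  // analyzer pick the default (per the empty Seq in the plan above).
  df.transpose().show()
}}}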
- * - * Scala Examples: - * {{{ - * spark.table("source") - * .mergeInto("target", $"source.id" === $"target.id") - * .whenMatched($"salary" === 100) - * .delete() - * .whenNotMatched() - * .insertAll() - * .whenNotMatchedBySource($"salary" === 100) - * .update(Map( - * "salary" -> lit(200) - * )) - * .merge() - * }}} - * - * @group basic - * @since 4.0.0 - */ + /** @inheritdoc */ def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = { if (isStreaming) { logicalPlan.failAnalysis( @@ -1664,7 +1615,7 @@ class Dataset[T] private[sql]( messageParameters = Map("methodName" -> toSQLId("mergeInto"))) } - new MergeIntoWriter[T](table, this, condition) + new MergeIntoWriterImpl[T](table, this, condition) } /** @@ -1684,7 +1635,7 @@ class Dataset[T] private[sql]( /** @inheritdoc */ override def toJSON: Dataset[String] = { - val rowSchema = this.schema + val rowSchema = exprEnc.schema val sessionLocalTimeZone = sparkSession.sessionState.conf.sessionLocalTimeZone mapPartitions { iter => val writer = new CharArrayWriter() @@ -1952,6 +1903,10 @@ class Dataset[T] private[sql]( override def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] = super.dropDuplicatesWithinWatermark(col1, cols: _*) + /** @inheritdoc */ + override def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = + super.mapPartitions(f, encoder) + /** @inheritdoc */ override def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] = super.flatMap(func) @@ -1959,9 +1914,6 @@ class Dataset[T] private[sql]( override def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = super.flatMap(f, encoder) - /** @inheritdoc */ - override def foreach(func: ForeachFunction[T]): Unit = super.foreach(func) - /** @inheritdoc */ override def foreachPartition(func: ForeachPartitionFunction[T]): Unit = super.foreachPartition(func) @@ -2018,18 +1970,16 @@ class Dataset[T] private[sql]( @scala.annotation.varargs override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) + /** @inheritdoc */ + override def groupByKey[K]( + func: MapFunction[T, K], + encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = + super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] + //////////////////////////////////////////////////////////////////////////// // For Python API //////////////////////////////////////////////////////////////////////////// - /** - * It adds a new long column with the name `name` that increases one by one. - * This is for 'distributed-sequence' default index in pandas API on Spark. - */ - private[sql] def withSequenceColumn(name: String) = { - select(column(DistributedSequenceID()).alias(name), col("*")) - } - /** * Converts a JavaRDD to a PythonRDD. */ @@ -2257,7 +2207,7 @@ class Dataset[T] private[sql]( /** A convenient function to wrap a set based logical plan and produce a Dataset. */ @inline private def withSetOperator[U : Encoder](logicalPlan: LogicalPlan): Dataset[U] = { - if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) { + if (isUnTyped) { // Set operators widen types (change the schema), so we cannot reuse the row encoder. Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]] } else { @@ -2265,6 +2215,8 @@ class Dataset[T] private[sql]( } } + private def isUnTyped: Boolean = classTag.runtimeClass.isAssignableFrom(classOf[Row]) + /** Returns a optimized plan for CommandResult, convert to `LocalRelation`. 
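The `isUnTyped` helper above preserves the existing rule that untyped set operations derive a fresh row encoder, because set operators may widen the schema; a small sketch of the situation this guards against (hypothetical frames, assuming the analyzer's usual type widening for unions and `import spark.implicits._`):
{{{
  val ints  = Seq(1).toDF("v")    // v: int
  val longs = Seq(1L).toDF("v")   // v: bigint

  // The union's schema is widened, so a fresh row encoder is derived from the
  // new plan instead of reusing the encoder captured by `ints`.
  val unioned = ints.union(longs)
  unioned.printSchema()
}}}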
*/ private def commandResultOptimized: Dataset[T] = { logicalPlan match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala deleted file mode 100644 index 27012c471462d..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -/** - * The abstract class for writing custom logic to process data generated by a query. - * This is often used to write the output of a streaming query to arbitrary storage systems. - * Any implementation of this base class will be used by Spark in the following way. - * - *
      - *
    • A single instance of this class is responsible of all the data generated by a single task - * in a query. In other words, one instance is responsible for processing one partition of the - * data generated in a distributed manner. - * - *
    • Any implementation of this class must be serializable because each task will get a fresh - * serialized-deserialized copy of the provided object. Hence, it is strongly recommended that - * any initialization for writing data (e.g. opening a connection or starting a transaction) - * is done after the `open(...)` method has been called, which signifies that the task is - * ready to generate data. - * - *
    • The lifecycle of the methods are as follows. - * - *
      - *   For each partition with `partitionId`:
      - *       For each batch/epoch of streaming data (if its streaming query) with `epochId`:
      - *           Method `open(partitionId, epochId)` is called.
      - *           If `open` returns true:
      - *                For each row in the partition and batch/epoch, method `process(row)` is called.
      - *           Method `close(errorOrNull)` is called with error (if any) seen while processing rows.
      - *
      - *
      - *
    - * - * Important points to note: - *
      - *
    • Spark doesn't guarantee same output for (partitionId, epochId), so deduplication - * cannot be achieved with (partitionId, epochId). e.g. source provides different number of - * partitions for some reason, Spark optimization changes number of partitions, etc. - * Refer SPARK-28650 for more details. If you need deduplication on output, try out - * `foreachBatch` instead. - * - *
    • The `close()` method will be called if `open()` method returns successfully (irrespective - * of the return value), except if the JVM crashes in the middle. - *
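Because the deleted scaladoc above points users who need per-epoch deduplication to `foreachBatch`, here is a hedged sketch of that alternative; `streamingDF`, the paths, and the write format are invented for the example:
{{{
  import org.apache.spark.sql.DataFrame

  // An explicitly typed Function2 keeps the Scala foreachBatch overload unambiguous.
  val writeBatch: (DataFrame, Long) => Unit = (batch, epochId) =>
    batch.write.mode("append").parquet(s"/tmp/example-output/epoch=$epochId")

  streamingDF.writeStream
    .foreachBatch(writeBatch)
    .option("checkpointLocation", "/tmp/example-checkpoint")
    .start()
}}}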
    - * - * Scala example: - * {{{ - * datasetOfString.writeStream.foreach(new ForeachWriter[String] { - * - * def open(partitionId: Long, version: Long): Boolean = { - * // open connection - * } - * - * def process(record: String) = { - * // write string to connection - * } - * - * def close(errorOrNull: Throwable): Unit = { - * // close the connection - * } - * }) - * }}} - * - * Java example: - * {{{ - * datasetOfString.writeStream().foreach(new ForeachWriter() { - * - * @Override - * public boolean open(long partitionId, long version) { - * // open connection - * } - * - * @Override - * public void process(String value) { - * // write string to connection - * } - * - * @Override - * public void close(Throwable errorOrNull) { - * // close the connection - * } - * }); - * }}} - * - * @since 2.0.0 - */ -abstract class ForeachWriter[T] extends Serializable { - - // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. - - /** - * Called when starting to process one partition of new data in the executor. See the class - * docs for more information on how to use the `partitionId` and `epochId`. - * - * @param partitionId the partition id. - * @param epochId a unique id for data deduplication. - * @return `true` if the corresponding partition and version id should be processed. `false` - * indicates the partition should be skipped. - */ - def open(partitionId: Long, epochId: Long): Boolean - - /** - * Called to process the data in the executor side. This method will be called only if `open` - * returns `true`. - */ - def process(value: T): Unit - - /** - * Called when stopping to process one partition of new data in the executor side. This is - * guaranteed to be called either `open` returns `true` or `false`. However, - * `close` won't be called in the following cases: - * - *
      - *
    • JVM crashes without throwing a `Throwable`
 - *    • `open` throws a `Throwable`.
 - *
    - * - * @param errorOrNull the error thrown during processing data or null if there was no error. - */ - def close(errorOrNull: Throwable): Unit -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index e3ea33a7504bf..c645ba57e8f82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql -import scala.jdk.CollectionConverters._ - import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.analysis.{EliminateEventTimeWatermark, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder} +import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.internal.TypedAggUtils.{aggKeyColumn, withInputType} @@ -41,79 +41,41 @@ class KeyValueGroupedDataset[K, V] private[sql]( vEncoder: Encoder[V], @transient val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], - private val groupingAttributes: Seq[Attribute]) extends Serializable { + private val groupingAttributes: Seq[Attribute]) + extends api.KeyValueGroupedDataset[K, V] { + type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL] - // Similar to [[Dataset]], we turn the passed in encoder to `ExpressionEncoder` explicitly. - private implicit val kExprEnc: ExpressionEncoder[K] = encoderFor(kEncoder) - private implicit val vExprEnc: ExpressionEncoder[V] = encoderFor(vEncoder) + private implicit def kEncoderImpl: Encoder[K] = kEncoder + private implicit def vEncoderImpl: Encoder[V] = vEncoder private def logicalPlan = queryExecution.analyzed private def sparkSession = queryExecution.sparkSession import queryExecution.sparkSession._ - /** - * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the - * specified type. The mapping of key columns to the type follows the same rules as `as` on - * [[Dataset]]. - * - * @since 1.6.0 - */ + /** @inheritdoc */ def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] = new KeyValueGroupedDataset( - encoderFor[L], - vExprEnc, + implicitly[Encoder[L]], + vEncoder, queryExecution, dataAttributes, groupingAttributes) - /** - * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied - * to the data. The grouping key is unchanged by this. 
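As a concrete illustration of the `keyAs`/`mapValues`/`keys` surface now documented via `@inheritdoc` above, a minimal sketch (assumes `import spark.implicits._`; the data is invented):
{{{
  val ds = Seq("apple" -> 1, "apple" -> 3, "banana" -> 2).toDS()

  val grouped = ds.groupByKey(_._1)      // KeyValueGroupedDataset[String, (String, Int)]
  val values  = grouped.mapValues(_._2)  // KeyValueGroupedDataset[String, Int]

  values.keys.show()                     // the distinct grouping keys
}}}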
- * - * {{{ - * // Create values grouped by key from a Dataset[(K, V)] - * ds.groupByKey(_._1).mapValues(_._2) // Scala - * }}} - * - * @since 2.1.0 - */ + /** @inheritdoc */ def mapValues[W : Encoder](func: V => W): KeyValueGroupedDataset[K, W] = { val withNewData = AppendColumns(func, dataAttributes, logicalPlan) val projected = Project(withNewData.newColumns ++ groupingAttributes, withNewData) val executed = sparkSession.sessionState.executePlan(projected) new KeyValueGroupedDataset( - encoderFor[K], - encoderFor[W], + kEncoder, + implicitly[Encoder[W]], executed, withNewData.newColumns, groupingAttributes) } - /** - * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied - * to the data. The grouping key is unchanged by this. - * - * {{{ - * // Create Integer values grouped by String key from a Dataset> - * Dataset> ds = ...; - * KeyValueGroupedDataset grouped = - * ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT()); - * }}} - * - * @since 2.1.0 - */ - def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = { - implicit val uEnc = encoder - mapValues { (v: V) => func.call(v) } - } - - /** - * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping - * over the Dataset to extract the keys and then running a distinct operation on those. - * - * @since 1.6.0 - */ + /** @inheritdoc */ def keys: Dataset[K] = { Dataset[K]( sparkSession, @@ -121,194 +83,23 @@ class KeyValueGroupedDataset[K, V] private[sql]( Project(groupingAttributes, logicalPlan))) } - /** - * (Scala-specific) - * Applies the given function to each group of data. For each unique group, the function will - * be passed the group key and an iterator that contains all of the elements in the group. The - * function can return an iterator containing elements of an arbitrary type which will be returned - * as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. - * - * @since 1.6.0 - */ - def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { - Dataset[U]( - sparkSession, - MapGroups( - f, - groupingAttributes, - dataAttributes, - Seq.empty, - logicalPlan)) - } - - /** - * (Java-specific) - * Applies the given function to each group of data. For each unique group, the function will - * be passed the group key and an iterator that contains all of the elements in the group. The - * function can return an iterator containing elements of an arbitrary type which will be returned - * as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. 
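A short, hedged sketch of `flatMapGroups` as described in the scaladoc above (assumes `import spark.implicits._`):
{{{
  val ds = Seq(("a", 1), ("a", 2), ("b", 5)).toDS()

  // Emit one (key, sum) pair per group; as the docs above warn, the group
  // iterator should not be materialized carelessly for very large groups.
  val sums = ds.groupByKey(_._1).flatMapGroups { (key, rows) =>
    Iterator(key -> rows.map(_._2).sum)
  }
}}}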
- * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. - * - * @since 1.6.0 - */ - def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) - } - - /** - * (Scala-specific) - * Applies the given function to each group of data. For each unique group, the function will - * be passed the group key and a sorted iterator that contains all of the elements in the group. - * The function can return an iterator containing elements of an arbitrary type which will be - * returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. - * - * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator - * to be sorted according to the given sort expressions. That sorting does not add - * computational complexity. - * - * @see [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]] - * @since 3.4.0 - */ + /** @inheritdoc */ def flatMapSortedGroups[U : Encoder]( sortExprs: Column*)( f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { - val sortOrder: Seq[SortOrder] = MapGroups.sortOrder(sortExprs.map(_.expr)) - Dataset[U]( sparkSession, MapGroups( f, groupingAttributes, dataAttributes, - sortOrder, + MapGroups.sortOrder(sortExprs.map(_.expr)), logicalPlan ) ) } - /** - * (Java-specific) - * Applies the given function to each group of data. For each unique group, the function will - * be passed the group key and a sorted iterator that contains all of the elements in the group. - * The function can return an iterator containing elements of an arbitrary type which will be - * returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. - * - * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator - * to be sorted according to the given sort expressions. That sorting does not add - * computational complexity. 
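And the sorted variant covered above, where each group's iterator is handed over in the requested order; the `Event` type and its fields are hypothetical, and `import spark.implicits._` is assumed:
{{{
  case class Event(user: String, ts: Long)
  val events = Seq(Event("u1", 3L), Event("u1", 1L), Event("u2", 2L)).toDS()

  // Each group's iterator arrives sorted by `ts`, so the head is the earliest event.
  val firstEvent = events.groupByKey(_.user).flatMapSortedGroups($"ts".asc) { (user, ordered) =>
    ordered.take(1).map(e => user -> e.ts)
  }
}}}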
- * - * @see [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]] - * @since 3.4.0 - */ - def flatMapSortedGroups[U]( - SortExprs: Array[Column], - f: FlatMapGroupsFunction[K, V, U], - encoder: Encoder[U]): Dataset[U] = { - import org.apache.spark.util.ArrayImplicits._ - flatMapSortedGroups( - SortExprs.toImmutableArraySeq: _*)((key, data) => f.call(key, data.asJava).asScala)(encoder) - } - - /** - * (Scala-specific) - * Applies the given function to each group of data. For each unique group, the function will - * be passed the group key and an iterator that contains all of the elements in the group. The - * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. - * - * @since 1.6.0 - */ - def mapGroups[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { - val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) - flatMapGroups(func) - } - - /** - * (Java-specific) - * Applies the given function to each group of data. For each unique group, the function will - * be passed the group key and an iterator that contains all of the elements in the group. The - * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. - * - * @since 1.6.0 - */ - def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - mapGroups((key, data) => f.call(key, data.asJava))(encoder) - } - - /** - * (Scala-specific) - * Applies the given function to each group of data, while maintaining a user-defined per-group - * state. The result Dataset will represent the objects returned by the function. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger, and - * updates to each group's state will be saved across invocations. - * See [[org.apache.spark.sql.streaming.GroupState]] for more details. - * - * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
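A hedged sketch of the Scala `mapGroupsWithState` flavour described above, keeping a running count per key; the streaming Dataset `events` is hypothetical and `import spark.implicits._` is assumed:
{{{
  import org.apache.spark.sql.streaming.GroupState

  def runningCount(key: String, values: Iterator[String], state: GroupState[Long]): (String, Long) = {
    val updated = state.getOption.getOrElse(0L) + values.size
    state.update(updated)
    key -> updated
  }

  // `events` is assumed to be a streaming Dataset[String].
  val counts = events.groupByKey(identity).mapGroupsWithState(runningCount _)
}}}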
- * @since 2.2.0 - */ + /** @inheritdoc */ def mapGroupsWithState[S: Encoder, U: Encoder]( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s)) @@ -324,23 +115,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( child = logicalPlan)) } - /** - * (Scala-specific) - * Applies the given function to each group of data, while maintaining a user-defined per-group - * state. The result Dataset will represent the objects returned by the function. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger, and - * updates to each group's state will be saved across invocations. - * See [[org.apache.spark.sql.streaming.GroupState]] for more details. - * - * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param timeoutConf Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.2.0 - */ + /** @inheritdoc */ def mapGroupsWithState[S: Encoder, U: Encoder]( timeoutConf: GroupStateTimeout)( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { @@ -357,29 +132,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( child = logicalPlan)) } - /** - * (Scala-specific) - * Applies the given function to each group of data, while maintaining a user-defined per-group - * state. The result Dataset will represent the objects returned by the function. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger, and - * updates to each group's state will be saved across invocations. - * See [[org.apache.spark.sql.streaming.GroupState]] for more details. - * - * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param timeoutConf Timeout Conf, see GroupStateTimeout for more details - * @param initialState The user provided state that will be initialized when the first batch - * of data is processed in the streaming query. The user defined function - * will be called on the state data even if there are no other values in - * the group. To convert a Dataset ds of type Dataset[(K, S)] to a - * KeyValueGroupedDataset[K, S] - * do {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}} - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.2.0 - */ + /** @inheritdoc */ def mapGroupsWithState[S: Encoder, U: Encoder]( timeoutConf: GroupStateTimeout, initialState: KeyValueGroupedDataset[K, S])( @@ -402,114 +155,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( )) } - /** - * (Java-specific) - * Applies the given function to each group of data, while maintaining a user-defined per-group - * state. The result Dataset will represent the objects returned by the function. - * For a static batch Dataset, the function will be invoked once per group. 
For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger, and - * updates to each group's state will be saved across invocations. - * See `GroupState` for more details. - * - * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param stateEncoder Encoder for the state type. - * @param outputEncoder Encoder for the output type. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.2.0 - */ - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U]): Dataset[U] = { - mapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s) - )(stateEncoder, outputEncoder) - } - - /** - * (Java-specific) - * Applies the given function to each group of data, while maintaining a user-defined per-group - * state. The result Dataset will represent the objects returned by the function. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger, and - * updates to each group's state will be saved across invocations. - * See `GroupState` for more details. - * - * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param stateEncoder Encoder for the state type. - * @param outputEncoder Encoder for the output type. - * @param timeoutConf Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.2.0 - */ - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout): Dataset[U] = { - mapGroupsWithState[S, U](timeoutConf)( - (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s) - )(stateEncoder, outputEncoder) - } - - /** - * (Java-specific) - * Applies the given function to each group of data, while maintaining a user-defined per-group - * state. The result Dataset will represent the objects returned by the function. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger, and - * updates to each group's state will be saved across invocations. - * See `GroupState` for more details. - * - * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param stateEncoder Encoder for the state type. - * @param outputEncoder Encoder for the output type. - * @param timeoutConf Timeout configuration for groups that do not receive data for a while. - * @param initialState The user provided state that will be initialized when the first batch - * of data is processed in the streaming query. The user defined function - * will be called on the state data even if there are no other values in - * the group. 
- * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.2.0 - */ - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { - mapGroupsWithState[S, U](timeoutConf, initialState)( - (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s) - )(stateEncoder, outputEncoder) - } - - /** - * (Scala-specific) - * Applies the given function to each group of data, while maintaining a user-defined per-group - * state. The result Dataset will represent the objects returned by the function. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger, and - * updates to each group's state will be saved across invocations. - * See `GroupState` for more details. - * - * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param outputMode The output mode of the function. - * @param timeoutConf Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.2.0 - */ + /** @inheritdoc */ def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout)( @@ -529,29 +175,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( child = logicalPlan)) } - /** - * (Scala-specific) - * Applies the given function to each group of data, while maintaining a user-defined per-group - * state. The result Dataset will represent the objects returned by the function. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger, and - * updates to each group's state will be saved across invocations. - * See `GroupState` for more details. - * - * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param outputMode The output mode of the function. - * @param timeoutConf Timeout configuration for groups that do not receive data for a while. - * @param initialState The user provided state that will be initialized when the first batch - * of data is processed in the streaming query. The user defined function - * will be called on the state data even if there are no other values in - * the group. To covert a Dataset `ds` of type of type `Dataset[(K, S)]` - * to a `KeyValueGroupedDataset[K, S]`, use - * {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}} - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.2.0 - */ + /** @inheritdoc */ def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout, @@ -576,91 +200,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( )) } - /** - * (Java-specific) - * Applies the given function to each group of data, while maintaining a user-defined per-group - * state. The result Dataset will represent the objects returned by the function. 
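For the `flatMapGroupsWithState` overloads above, a hedged sketch using Update output mode and a processing-time timeout; the `Click` type and the `clicks` stream are invented, and `import spark.implicits._` is assumed:
{{{
  import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

  case class Click(sessionId: String, url: String)

  // `clicks` is assumed to be a streaming Dataset[Click].
  val perSession = clicks
    .groupByKey(_.sessionId)
    .flatMapGroupsWithState(OutputMode.Update(), GroupStateTimeout.ProcessingTimeTimeout()) {
      (session: String, events: Iterator[Click], state: GroupState[Int]) =>
        val total = state.getOption.getOrElse(0) + events.size
        state.update(total)
        Iterator(session -> total)
    }
}}}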
- * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger, and - * updates to each group's state will be saved across invocations. - * See `GroupState` for more details. - * - * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param outputMode The output mode of the function. - * @param stateEncoder Encoder for the state type. - * @param outputEncoder Encoder for the output type. - * @param timeoutConf Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.2.0 - */ - def flatMapGroupsWithState[S, U]( - func: FlatMapGroupsWithStateFunction[K, V, S, U], - outputMode: OutputMode, - stateEncoder: Encoder[S], - outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout): Dataset[U] = { - val f = (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s).asScala - flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder) - } - - /** - * (Java-specific) - * Applies the given function to each group of data, while maintaining a user-defined per-group - * state. The result Dataset will represent the objects returned by the function. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger, and - * updates to each group's state will be saved across invocations. - * See `GroupState` for more details. - * - * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param outputMode The output mode of the function. - * @param stateEncoder Encoder for the state type. - * @param outputEncoder Encoder for the output type. - * @param timeoutConf Timeout configuration for groups that do not receive data for a while. - * @param initialState The user provided state that will be initialized when the first batch - * of data is processed in the streaming query. The user defined function - * will be called on the state data even if there are no other values in - * the group. To covert a Dataset `ds` of type of type `Dataset[(K, S)]` - * to a `KeyValueGroupedDataset[K, S]`, use - * {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}} - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.2.0 - */ - def flatMapGroupsWithState[S, U]( - func: FlatMapGroupsWithStateFunction[K, V, S, U], - outputMode: OutputMode, - stateEncoder: Encoder[S], - outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { - val f = (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s).asScala - flatMapGroupsWithState[S, U]( - outputMode, timeoutConf, initialState)(f)(stateEncoder, outputEncoder) - } - - /** - * (Scala-specific) - * Invokes methods defined in the stateful processor used in arbitrary state API v2. - * We allow the user to act on per-group set of input rows along with keyed state and the - * user can choose to output/return 0 or more rows. 
- * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows - * in each trigger and the user's state/state variables will be stored persistently across - * invocations. - * - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked - * by the operator. - * @param timeMode The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode The output mode of the stateful processor. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - */ + /** @inheritdoc */ private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeMode: TimeMode, @@ -678,29 +218,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( ) } - /** - * (Scala-specific) - * Invokes methods defined in the stateful processor used in arbitrary state API v2. - * We allow the user to act on per-group set of input rows along with keyed state and the - * user can choose to output/return 0 or more rows. - * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows - * in each trigger and the user's state/state variables will be stored persistently across - * invocations. - * - * Downstream operators would use specified eventTimeColumnName to calculate watermark. - * Note that TimeMode is set to EventTime to ensure correct flow of watermark. - * - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param statefulProcessor Instance of statefulProcessor whose functions will - * be invoked by the operator. - * @param eventTimeColumnName eventTime column in the output dataset. Any operations after - * transformWithState will use the new eventTimeColumn. The user - * needs to ensure that the eventTime for emitted output adheres to - * the watermark boundary, otherwise streaming query will fail. - * @param outputMode The output mode of the stateful processor. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - */ + /** @inheritdoc */ private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], eventTimeColumnName: String, @@ -716,81 +234,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( updateEventTimeColumnAfterTransformWithState(transformWithState, eventTimeColumnName) } - /** - * (Java-specific) - * Invokes methods defined in the stateful processor used in arbitrary state API v2. - * We allow the user to act on per-group set of input rows along with keyed state and the - * user can choose to output/return 0 or more rows. - * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows - * in each trigger and the user's state/state variables will be stored persistently across - * invocations. - * - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the - * operator. - * @param timeMode The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode The output mode of the stateful processor. - * @param outputEncoder Encoder for the output type. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
- */ - private[sql] def transformWithState[U: Encoder]( - statefulProcessor: StatefulProcessor[K, V, U], - timeMode: TimeMode, - outputMode: OutputMode, - outputEncoder: Encoder[U]): Dataset[U] = { - transformWithState(statefulProcessor, timeMode, outputMode)(outputEncoder) - } - - /** - * (Java-specific) - * Invokes methods defined in the stateful processor used in arbitrary state API v2. - * We allow the user to act on per-group set of input rows along with keyed state and the - * user can choose to output/return 0 or more rows. - * - * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows - * in each trigger and the user's state/state variables will be stored persistently across - * invocations. - * - * Downstream operators would use specified eventTimeColumnName to calculate watermark. - * Note that TimeMode is set to EventTime to ensure correct flow of watermark. - * - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the - * operator. - * @param eventTimeColumnName eventTime column in the output dataset. Any operations after - * transformWithState will use the new eventTimeColumn. The user - * needs to ensure that the eventTime for emitted output adheres to - * the watermark boundary, otherwise streaming query will fail. - * @param outputMode The output mode of the stateful processor. - * @param outputEncoder Encoder for the output type. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - */ - private[sql] def transformWithState[U: Encoder]( - statefulProcessor: StatefulProcessor[K, V, U], - eventTimeColumnName: String, - outputMode: OutputMode, - outputEncoder: Encoder[U]): Dataset[U] = { - transformWithState(statefulProcessor, eventTimeColumnName, outputMode)(outputEncoder) - } - - /** - * (Scala-specific) - * Invokes methods defined in the stateful processor used in arbitrary state API v2. - * Functions as the function above, but with additional initial state. - * - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @tparam S The type of initial state objects. Must be encodable to Spark SQL types. - * @param statefulProcessor Instance of statefulProcessor whose functions will - * be invoked by the operator. - * @param timeMode The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode The output mode of the stateful processor. - * @param initialState User provided initial state that will be used to initiate state for - * the query in the first batch. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - */ + /** @inheritdoc */ private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeMode: TimeMode, @@ -812,29 +256,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( ) } - /** - * (Scala-specific) - * Invokes methods defined in the stateful processor used in arbitrary state API v2. - * Functions as the function above, but with additional eventTimeColumnName for output. - * - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @tparam S The type of initial state objects. Must be encodable to Spark SQL types. - * - * Downstream operators would use specified eventTimeColumnName to calculate watermark. - * Note that TimeMode is set to EventTime to ensure correct flow of watermark. 
- * - * @param statefulProcessor Instance of statefulProcessor whose functions will - * be invoked by the operator. - * @param eventTimeColumnName eventTime column in the output dataset. Any operations after - * transformWithState will use the new eventTimeColumn. The user - * needs to ensure that the eventTime for emitted output adheres to - * the watermark boundary, otherwise streaming query will fail. - * @param outputMode The output mode of the stateful processor. - * @param initialState User provided initial state that will be used to initiate state for - * the query in the first batch. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - */ + /** @inheritdoc */ private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], eventTimeColumnName: String, @@ -855,71 +277,6 @@ class KeyValueGroupedDataset[K, V] private[sql]( updateEventTimeColumnAfterTransformWithState(transformWithState, eventTimeColumnName) } - /** - * (Java-specific) - * Invokes methods defined in the stateful processor used in arbitrary state API v2. - * Functions as the function above, but with additional initialStateEncoder for state encoding. - * - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @tparam S The type of initial state objects. Must be encodable to Spark SQL types. - * @param statefulProcessor Instance of statefulProcessor whose functions will - * be invoked by the operator. - * @param timeMode The time mode semantics of the stateful processor for - * timers and TTL. - * @param outputMode The output mode of the stateful processor. - * @param initialState User provided initial state that will be used to initiate state for - * the query in the first batch. - * @param outputEncoder Encoder for the output type. - * @param initialStateEncoder Encoder for the initial state type. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - */ - private[sql] def transformWithState[U: Encoder, S: Encoder]( - statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], - timeMode: TimeMode, - outputMode: OutputMode, - initialState: KeyValueGroupedDataset[K, S], - outputEncoder: Encoder[U], - initialStateEncoder: Encoder[S]): Dataset[U] = { - transformWithState(statefulProcessor, timeMode, - outputMode, initialState)(outputEncoder, initialStateEncoder) - } - - /** - * (Java-specific) - * Invokes methods defined in the stateful processor used in arbitrary state API v2. - * Functions as the function above, but with additional eventTimeColumnName for output. - * - * Downstream operators would use specified eventTimeColumnName to calculate watermark. - * Note that TimeMode is set to EventTime to ensure correct flow of watermark. - * - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @tparam S The type of initial state objects. Must be encodable to Spark SQL types. - * @param statefulProcessor Instance of statefulProcessor whose functions will - * be invoked by the operator. - * @param outputMode The output mode of the stateful processor. - * @param initialState User provided initial state that will be used to initiate state for - * the query in the first batch. - * @param eventTimeColumnName event column in the output dataset. Any operations after - * transformWithState will use the new eventTimeColumn. 
The user - * needs to ensure that the eventTime for emitted output adheres to - * the watermark boundary, otherwise streaming query will fail. - * @param outputEncoder Encoder for the output type. - * @param initialStateEncoder Encoder for the initial state type. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - */ - private[sql] def transformWithState[U: Encoder, S: Encoder]( - statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], - outputMode: OutputMode, - initialState: KeyValueGroupedDataset[K, S], - eventTimeColumnName: String, - outputEncoder: Encoder[U], - initialStateEncoder: Encoder[S]): Dataset[U] = { - transformWithState(statefulProcessor, eventTimeColumnName, - outputMode, initialState)(outputEncoder, initialStateEncoder) - } - /** * Creates a new dataset with updated eventTimeColumn after the transformWithState * logical node. @@ -939,125 +296,238 @@ class KeyValueGroupedDataset[K, V] private[sql]( transformWithStateDataset.logicalPlan))) } - /** - * (Scala-specific) - * Reduces the elements of each group of data using the specified binary function. - * The given function must be commutative and associative or the result may be non-deterministic. - * - * @since 1.6.0 - */ + /** @inheritdoc */ def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { - val vEncoder = encoderFor[V] val aggregator: TypedColumn[V, V] = new ReduceAggregator[V](f)(vEncoder).toColumn agg(aggregator) } - /** - * (Java-specific) - * Reduces the elements of each group of data using the specified binary function. - * The given function must be commutative and associative or the result may be non-deterministic. - * - * @since 1.6.0 - */ - def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = { - reduceGroups(f.call _) - } - - /** - * Internal helper function for building typed aggregations that return tuples. For simplicity - * and code reuse, we do this without the help of the type system and then use helper functions - * that cast appropriately for the user facing interface. 
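A quick sketch of `reduceGroups`, which after the change above builds its `ReduceAggregator` from the dataset's value encoder directly; the data is invented and, as the removed docs note, the function must be commutative and associative (assumes `import spark.implicits._`):
{{{
  val ds = Seq(("a", 1), ("a", 4), ("b", 2)).toDS()

  // Dataset[(String, Int)]: each key paired with the reduced value of its group.
  val maxPerKey = ds.groupByKey(_._1).mapValues(_._2).reduceGroups((l, r) => math.max(l, r))
}}}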
- */ + /** @inheritdoc */ protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val encoders = columns.map(c => encoderFor(c.encoder)) - val namedColumns = columns.map(c => withInputType(c.named, vExprEnc, dataAttributes)) - val keyColumn = aggKeyColumn(kExprEnc, groupingAttributes) + val keyAgEncoder = agnosticEncoderFor(kEncoder) + val valueExprEncoder = encoderFor(vEncoder) + val encoders = columns.map(c => agnosticEncoderFor(c.encoder)) + val namedColumns = columns.map { c => + withInputType(c.named, valueExprEncoder, dataAttributes) + } + val keyColumn = aggKeyColumn(keyAgEncoder, groupingAttributes) val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) - val execution = new QueryExecution(sparkSession, aggregate) + new Dataset(sparkSession, aggregate, ProductEncoder.tuple(keyAgEncoder +: encoders)) + } + + /** @inheritdoc */ + def cogroupSorted[U, R : Encoder]( + other: KeyValueGroupedDataset[K, U])( + thisSortExprs: Column*)( + otherSortExprs: Column*)( + f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { + implicit val uEncoder = other.vEncoderImpl + Dataset[R]( + sparkSession, + CoGroup( + f, + this.groupingAttributes, + other.groupingAttributes, + this.dataAttributes, + other.dataAttributes, + MapGroups.sortOrder(thisSortExprs.map(_.expr)), + MapGroups.sortOrder(otherSortExprs.map(_.expr)), + this.logicalPlan, + other.logicalPlan)) + } - new Dataset(execution, ExpressionEncoder.tuple(kExprEnc +: encoders)) + override def toString: String = { + val builder = new StringBuilder + val kFields = kEncoder.schema.map { f => + s"${f.name}: ${f.dataType.simpleString(2)}" + } + val vFields = vEncoder.schema.map { f => + s"${f.name}: ${f.dataType.simpleString(2)}" + } + builder.append("KeyValueGroupedDataset: [key: [") + builder.append(kFields.take(2).mkString(", ")) + if (kFields.length > 2) { + builder.append(" ... " + (kFields.length - 2) + " more field(s)") + } + builder.append("], value: [") + builder.append(vFields.take(2).mkString(", ")) + if (vFields.length > 2) { + builder.append(" ... " + (vFields.length - 2) + " more field(s)") + } + builder.append("]]").toString() } - /** - * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key - * and the result of computing this aggregation over all elements in the group. - * - * @since 1.6.0 - */ - def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = - aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] + //////////////////////////////////////////////////////////////////////////// + // Return type overrides to make sure we return the implementation instead + // of the interface. + //////////////////////////////////////////////////////////////////////////// + /** @inheritdoc */ + override def mapValues[W]( + func: MapFunction[V, W], + encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = super.mapValues(func, encoder) - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key - * and the result of computing these aggregations over all elements in the group. 
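From the user's side, the typed `agg` calls served by the rewritten `aggUntyped` above look like this hedged sketch (assumes `import spark.implicits._` and `import org.apache.spark.sql.functions._`):
{{{
  val ds = Seq(("a", 1), ("a", 2), ("b", 5)).toDS()

  // Dataset[(String, Long, Long)]: key, row count, and sum of the second field.
  val stats = ds.groupByKey(_._1).agg(
    count("*").as[Long],
    sum($"_2").as[Long])
}}}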
- * - * @since 1.6.0 - */ - def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = - aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] + /** @inheritdoc */ + override def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = + super.flatMapGroups(f) - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key - * and the result of computing these aggregations over all elements in the group. - * - * @since 1.6.0 - */ - def agg[U1, U2, U3]( + /** @inheritdoc */ + override def flatMapGroups[U]( + f: FlatMapGroupsFunction[K, V, U], + encoder: Encoder[U]): Dataset[U] = super.flatMapGroups(f, encoder) + + /** @inheritdoc */ + override def flatMapSortedGroups[U]( + SortExprs: Array[Column], + f: FlatMapGroupsFunction[K, V, U], + encoder: Encoder[U]): Dataset[U] = super.flatMapSortedGroups(SortExprs, f, encoder) + + /** @inheritdoc */ + override def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = super.mapGroups(f) + + /** @inheritdoc */ + override def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = + super.mapGroups(f, encoder) + + /** @inheritdoc */ + override def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = + super.mapGroupsWithState(func, stateEncoder, outputEncoder) + + /** @inheritdoc */ + override def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout): Dataset[U] = + super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf) + + /** @inheritdoc */ + override def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = + super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf, initialState) + + /** @inheritdoc */ + override def flatMapGroupsWithState[S, U]( + func: FlatMapGroupsWithStateFunction[K, V, S, U], + outputMode: OutputMode, + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout): Dataset[U] = + super.flatMapGroupsWithState(func, outputMode, stateEncoder, outputEncoder, timeoutConf) + + /** @inheritdoc */ + override def flatMapGroupsWithState[S, U]( + func: FlatMapGroupsWithStateFunction[K, V, S, U], + outputMode: OutputMode, + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = + super.flatMapGroupsWithState( + func, + outputMode, + stateEncoder, + outputEncoder, + timeoutConf, + initialState) + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder]( + statefulProcessor: StatefulProcessor[K, V, U], + timeMode: TimeMode, + outputMode: OutputMode, + outputEncoder: Encoder[U]) = + super.transformWithState(statefulProcessor, timeMode, outputMode, outputEncoder) + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder]( + statefulProcessor: StatefulProcessor[K, V, U], + eventTimeColumnName: String, + outputMode: OutputMode, + outputEncoder: Encoder[U]) = + super.transformWithState(statefulProcessor, eventTimeColumnName, outputMode, outputEncoder) + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder, S: Encoder]( + 
statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + timeMode: TimeMode, + outputMode: OutputMode, + initialState: KeyValueGroupedDataset[K, S], + outputEncoder: Encoder[U], + initialStateEncoder: Encoder[S]) = super.transformWithState( + statefulProcessor, + timeMode, + outputMode, + initialState, + outputEncoder, + initialStateEncoder) + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder, S: Encoder]( + statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + outputMode: OutputMode, + initialState: KeyValueGroupedDataset[K, S], + eventTimeColumnName: String, + outputEncoder: Encoder[U], + initialStateEncoder: Encoder[S]) = super.transformWithState( + statefulProcessor, + outputMode, + initialState, + eventTimeColumnName, + outputEncoder, + initialStateEncoder) + + /** @inheritdoc */ + override def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = super.reduceGroups(f) + + /** @inheritdoc */ + override def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = super.agg(col1) + + /** @inheritdoc */ + override def agg[U1, U2]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = super.agg(col1, col2) + + /** @inheritdoc */ + override def agg[U1, U2, U3]( col1: TypedColumn[V, U1], col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = - aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] + col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = super.agg(col1, col2, col3) - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key - * and the result of computing these aggregations over all elements in the group. - * - * @since 1.6.0 - */ - def agg[U1, U2, U3, U4]( + /** @inheritdoc */ + override def agg[U1, U2, U3, U4]( col1: TypedColumn[V, U1], col2: TypedColumn[V, U2], col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = - aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] + col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = super.agg(col1, col2, col3, col4) - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key - * and the result of computing these aggregations over all elements in the group. - * - * @since 3.0.0 - */ - def agg[U1, U2, U3, U4, U5]( + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5]( col1: TypedColumn[V, U1], col2: TypedColumn[V, U2], col3: TypedColumn[V, U3], col4: TypedColumn[V, U4], col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] = - aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5)]] + super.agg(col1, col2, col3, col4, col5) - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key - * and the result of computing these aggregations over all elements in the group. - * - * @since 3.0.0 - */ - def agg[U1, U2, U3, U4, U5, U6]( + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5, U6]( col1: TypedColumn[V, U1], col2: TypedColumn[V, U2], col3: TypedColumn[V, U3], col4: TypedColumn[V, U4], col5: TypedColumn[V, U5], col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] = - aggUntyped(col1, col2, col3, col4, col5, col6) - .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]] + super.agg(col1, col2, col3, col4, col5, col6) - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key - * and the result of computing these aggregations over all elements in the group. 
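The `mapGroupsWithState` overrides above only delegate to the shared implementation via `super`; the sketch below shows roughly how a caller reaches them. `Event` and `countClicks` are made-up names, and a batch Dataset is used so the snippet runs standalone (in batch mode the function is invoked once per group):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.GroupState

object MapGroupsWithStateExample {
  case class Event(user: String, clicks: Int)

  // Running per-user click count kept in GroupState.
  def countClicks(user: String, events: Iterator[Event], state: GroupState[Long]): (String, Long) = {
    val total = state.getOption.getOrElse(0L) + events.map(_.clicks).sum
    state.update(total)
    (user, total)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("mgws").getOrCreate()
    import spark.implicits._

    // In a real job this would be a streaming source; a batch Dataset also works
    // for mapGroupsWithState when no timeout is configured.
    val events = Seq(Event("u1", 2), Event("u1", 3), Event("u2", 1)).toDS()

    val totals = events
      .groupByKey(_.user)
      .mapGroupsWithState(countClicks _)

    totals.show()
    spark.stop()
  }
}
```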
- * - * @since 3.0.0 - */ - def agg[U1, U2, U3, U4, U5, U6, U7]( + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5, U6, U7]( col1: TypedColumn[V, U1], col2: TypedColumn[V, U2], col3: TypedColumn[V, U3], @@ -1065,16 +535,10 @@ class KeyValueGroupedDataset[K, V] private[sql]( col5: TypedColumn[V, U5], col6: TypedColumn[V, U6], col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] = - aggUntyped(col1, col2, col3, col4, col5, col6, col7) - .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]] + super.agg(col1, col2, col3, col4, col5, col6, col7) - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key - * and the result of computing these aggregations over all elements in the group. - * - * @since 3.0.0 - */ - def agg[U1, U2, U3, U4, U5, U6, U7, U8]( + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5, U6, U7, U8]( col1: TypedColumn[V, U1], col2: TypedColumn[V, U2], col3: TypedColumn[V, U3], @@ -1083,146 +547,30 @@ class KeyValueGroupedDataset[K, V] private[sql]( col6: TypedColumn[V, U6], col7: TypedColumn[V, U7], col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] = - aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8) - .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]] + super.agg(col1, col2, col3, col4, col5, col6, col7, col8) - /** - * Returns a [[Dataset]] that contains a tuple with each key and the number of items present - * for that key. - * - * @since 1.6.0 - */ - def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long]())) + /** @inheritdoc */ + override def count(): Dataset[(K, Long)] = super.count() - /** - * (Scala-specific) - * Applies the given function to each cogrouped data. For each unique group, the function will - * be passed the grouping key and 2 iterators containing all elements in the group from - * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an - * arbitrary type which will be returned as a new [[Dataset]]. - * - * @since 1.6.0 - */ - def cogroup[U, R : Encoder]( + /** @inheritdoc */ + override def cogroup[U, R: Encoder]( other: KeyValueGroupedDataset[K, U])( - f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { - implicit val uEncoder = other.vExprEnc - Dataset[R]( - sparkSession, - CoGroup( - f, - this.groupingAttributes, - other.groupingAttributes, - this.dataAttributes, - other.dataAttributes, - Seq.empty, - Seq.empty, - this.logicalPlan, - other.logicalPlan)) - } + f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = + super.cogroup(other)(f) - /** - * (Java-specific) - * Applies the given function to each cogrouped data. For each unique group, the function will - * be passed the grouping key and 2 iterators containing all elements in the group from - * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an - * arbitrary type which will be returned as a new [[Dataset]]. - * - * @since 1.6.0 - */ - def cogroup[U, R]( + /** @inheritdoc */ + override def cogroup[U, R]( other: KeyValueGroupedDataset[K, U], f: CoGroupFunction[K, V, U, R], - encoder: Encoder[R]): Dataset[R] = { - cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) - } - - /** - * (Scala-specific) - * Applies the given function to each sorted cogrouped data. 
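The sorted-cogroup variant implemented earlier in this file can be exercised as follows. This is an illustrative sketch, not code from the patch; `Order`, `Payment` and the `ts` ordering column are assumed names:

```scala
import org.apache.spark.sql.SparkSession

object CogroupSortedExample {
  case class Order(customer: String, ts: Long, item: String)
  case class Payment(customer: String, ts: Long, amount: Double)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("cogroup-sorted").getOrCreate()
    import spark.implicits._

    val orders   = Seq(Order("c1", 2L, "book"), Order("c1", 1L, "pen")).toDS().groupByKey(_.customer)
    val payments = Seq(Payment("c1", 1L, 9.5)).toDS().groupByKey(_.customer)

    // Both iterators arrive sorted by the given expressions (here: event time); this is what
    // MapGroups.sortOrder in the implementation above turns into SortOrder expressions.
    val merged = orders.cogroupSorted(payments)($"ts")($"ts") { (customer, os, ps) =>
      Iterator(customer -> (os.size, ps.size))
    }

    merged.show()
    spark.stop()
  }
}
```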
For each unique group, the function - * will be passed the grouping key and 2 sorted iterators containing all elements in the group - * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements - * of an arbitrary type which will be returned as a new [[Dataset]]. - * - * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators - * to be sorted according to the given sort expressions. That sorting does not add - * computational complexity. - * - * @see [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]] - * @since 3.4.0 - */ - def cogroupSorted[U, R : Encoder]( - other: KeyValueGroupedDataset[K, U])( - thisSortExprs: Column*)( - otherSortExprs: Column*)( - f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { - def toSortOrder(col: Column): SortOrder = col.expr match { - case expr: SortOrder => expr - case expr: Expression => SortOrder(expr, Ascending) - } + encoder: Encoder[R]): Dataset[R] = + super.cogroup(other, f, encoder) - val thisSortOrder: Seq[SortOrder] = thisSortExprs.map(toSortOrder) - val otherSortOrder: Seq[SortOrder] = otherSortExprs.map(toSortOrder) - - implicit val uEncoder = other.vExprEnc - Dataset[R]( - sparkSession, - CoGroup( - f, - this.groupingAttributes, - other.groupingAttributes, - this.dataAttributes, - other.dataAttributes, - thisSortOrder, - otherSortOrder, - this.logicalPlan, - other.logicalPlan)) - } - - /** - * (Java-specific) - * Applies the given function to each sorted cogrouped data. For each unique group, the function - * will be passed the grouping key and 2 sorted iterators containing all elements in the group - * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements - * of an arbitrary type which will be returned as a new [[Dataset]]. - * - * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators - * to be sorted according to the given sort expressions. That sorting does not add - * computational complexity. - * - * @see [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]] - * @since 3.4.0 - */ - def cogroupSorted[U, R]( + /** @inheritdoc */ + override def cogroupSorted[U, R]( other: KeyValueGroupedDataset[K, U], thisSortExprs: Array[Column], otherSortExprs: Array[Column], f: CoGroupFunction[K, V, U, R], - encoder: Encoder[R]): Dataset[R] = { - import org.apache.spark.util.ArrayImplicits._ - cogroupSorted(other)( - thisSortExprs.toImmutableArraySeq: _*)(otherSortExprs.toImmutableArraySeq: _*)( - (key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) - } - - override def toString: String = { - val builder = new StringBuilder - val kFields = kExprEnc.schema.map { - case f => s"${f.name}: ${f.dataType.simpleString(2)}" - } - val vFields = vExprEnc.schema.map { - case f => s"${f.name}: ${f.dataType.simpleString(2)}" - } - builder.append("KeyValueGroupedDataset: [key: [") - builder.append(kFields.take(2).mkString(", ")) - if (kFields.length > 2) { - builder.append(" ... " + (kFields.length - 2) + " more field(s)") - } - builder.append("], value: [") - builder.append(vFields.take(2).mkString(", ")) - if (vFields.length > 2) { - builder.append(" ... 
" + (vFields.length - 2) + " more field(s)") - } - builder.append("]]").toString() - } + encoder: Encoder[R]): Dataset[R] = + super.cogroupSorted(other, thisSortExprs, otherSortExprs, f, encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala deleted file mode 100644 index 6212a7fdb259c..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala +++ /dev/null @@ -1,353 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.SparkRuntimeException -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, InsertAction, InsertStarAction, MergeAction, MergeIntoTable, UpdateAction, UpdateStarAction} -import org.apache.spark.sql.functions.expr - -/** - * `MergeIntoWriter` provides methods to define and execute merge actions based - * on specified conditions. - * - * @tparam T the type of data in the Dataset. - * @param table the name of the target table for the merge operation. - * @param ds the source Dataset to merge into the target table. - * @param on the merge condition. - * @param schemaEvolutionEnabled whether to enable automatic schema evolution for this merge - * operation. Default is `false`. - * - * @since 4.0.0 - */ -@Experimental -class MergeIntoWriter[T] private[sql] ( - table: String, - ds: Dataset[T], - on: Column, - private[sql] val schemaEvolutionEnabled: Boolean = false) { - - private val df: DataFrame = ds.toDF() - - private[sql] val sparkSession = ds.sparkSession - import sparkSession.RichColumn - - private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table) - - private val logicalPlan = df.queryExecution.logical - - private[sql] var matchedActions: Seq[MergeAction] = Seq.empty[MergeAction] - private[sql] var notMatchedActions: Seq[MergeAction] = Seq.empty[MergeAction] - private[sql] var notMatchedBySourceActions: Seq[MergeAction] = Seq.empty[MergeAction] - - /** - * Initialize a `WhenMatched` action without any condition. - * - * This `WhenMatched` action will be executed when a source row matches a target table row based - * on the merge condition. - * - * This `WhenMatched` can be followed by one of the following merge actions: - * - `updateAll`: Update all the matched target table rows with source dataset rows. - * - `update(Map)`: Update all the matched target table rows while changing only - * a subset of columns based on the provided assignment. 
- * - `delete`: Delete all target rows that have a match in the source table. - * - * @return a new `WhenMatched` object. - */ - def whenMatched(): WhenMatched[T] = { - new WhenMatched[T](this, None) - } - - /** - * Initialize a `WhenMatched` action with a condition. - * - * This `WhenMatched` action will be executed when a source row matches a target table row based - * on the merge condition and the specified `condition` is satisfied. - * - * This `WhenMatched` can be followed by one of the following merge actions: - * - `updateAll`: Update all the matched target table rows with source dataset rows. - * - `update(Map)`: Update all the matched target table rows while changing only - * a subset of columns based on the provided assignment. - * - `delete`: Delete all target rows that have a match in the source table. - * - * @param condition a `Column` representing the condition to be evaluated for the action. - * @return a new `WhenMatched` object configured with the specified condition. - */ - def whenMatched(condition: Column): WhenMatched[T] = { - new WhenMatched[T](this, Some(condition.expr)) - } - - /** - * Initialize a `WhenNotMatched` action without any condition. - * - * This `WhenNotMatched` action will be executed when a source row does not match any target row - * based on the merge condition. - * - * This `WhenNotMatched` can be followed by one of the following merge actions: - * - `insertAll`: Insert all rows from the source that are not already in the target table. - * - `insert(Map)`: Insert all rows from the source that are not already in the target table, - * with the specified columns based on the provided assignment. - * - * @return a new `WhenNotMatched` object. - */ - def whenNotMatched(): WhenNotMatched[T] = { - new WhenNotMatched[T](this, None) - } - - /** - * Initialize a `WhenNotMatched` action with a condition. - * - * This `WhenNotMatched` action will be executed when a source row does not match any target row - * based on the merge condition and the specified `condition` is satisfied. - * - * This `WhenNotMatched` can be followed by one of the following merge actions: - * - `insertAll`: Insert all rows from the source that are not already in the target table. - * - `insert(Map)`: Insert all rows from the source that are not already in the target table, - * with the specified columns based on the provided assignment. - * - * @param condition a `Column` representing the condition to be evaluated for the action. - * @return a new `WhenNotMatched` object configured with the specified condition. - */ - def whenNotMatched(condition: Column): WhenNotMatched[T] = { - new WhenNotMatched[T](this, Some(condition.expr)) - } - - /** - * Initialize a `WhenNotMatchedBySource` action without any condition. - * - * This `WhenNotMatchedBySource` action will be executed when a target row does not match any - * rows in the source table based on the merge condition. - * - * This `WhenNotMatchedBySource` can be followed by one of the following merge actions: - * - `updateAll`: Update all the not matched target table rows with source dataset rows. - * - `update(Map)`: Update all the not matched target table rows while changing only - * the specified columns based on the provided assignment. - * - `delete`: Delete all target rows that have no matches in the source table. - * - * @return a new `WhenNotMatchedBySource` object. 
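The merge builder whose classic copy is deleted here is driven through a `whenMatched`/`whenNotMatched`/`whenNotMatchedBySource` chain ending in `merge()`. A hedged usage sketch: the `Dataset.mergeInto` entry point is assumed, and `source`/`target` are hypothetical tables in a catalog whose tables support MERGE INTO (otherwise the command fails at analysis time):

```scala
import org.apache.spark.sql.SparkSession

object MergeIntoExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("merge-into").getOrCreate()
    import spark.implicits._

    // "source" and "target" are hypothetical catalog tables; plain temp views will not
    // support MERGE and the statement would fail during analysis.
    spark.table("source")
      .mergeInto("target", $"target.id" === $"source.id")
      .whenMatched()
      .updateAll()
      .whenNotMatched()
      .insertAll()
      .merge()

    spark.stop()
  }
}
```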
- */ - def whenNotMatchedBySource(): WhenNotMatchedBySource[T] = { - new WhenNotMatchedBySource[T](this, None) - } - - /** - * Initialize a `WhenNotMatchedBySource` action with a condition. - * - * This `WhenNotMatchedBySource` action will be executed when a target row does not match any - * rows in the source table based on the merge condition and the specified `condition` - * is satisfied. - * - * This `WhenNotMatchedBySource` can be followed by one of the following merge actions: - * - `updateAll`: Update all the not matched target table rows with source dataset rows. - * - `update(Map)`: Update all the not matched target table rows while changing only - * the specified columns based on the provided assignment. - * - `delete`: Delete all target rows that have no matches in the source table. - * - * @param condition a `Column` representing the condition to be evaluated for the action. - * @return a new `WhenNotMatchedBySource` object configured with the specified condition. - */ - def whenNotMatchedBySource(condition: Column): WhenNotMatchedBySource[T] = { - new WhenNotMatchedBySource[T](this, Some(condition.expr)) - } - - /** - * Enable automatic schema evolution for this merge operation. - * @return A `MergeIntoWriter` instance with schema evolution enabled. - */ - def withSchemaEvolution(): MergeIntoWriter[T] = { - new MergeIntoWriter[T](this.table, this.ds, this.on, schemaEvolutionEnabled = true) - .withNewMatchedActions(this.matchedActions: _*) - .withNewNotMatchedActions(this.notMatchedActions: _*) - .withNewNotMatchedBySourceActions(this.notMatchedBySourceActions: _*) - } - - /** - * Executes the merge operation. - */ - def merge(): Unit = { - if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) { - throw new SparkRuntimeException( - errorClass = "NO_MERGE_ACTION_SPECIFIED", - messageParameters = Map.empty) - } - - val merge = MergeIntoTable( - UnresolvedRelation(tableName).requireWritePrivileges(MergeIntoTable.getWritePrivileges( - matchedActions, notMatchedActions, notMatchedBySourceActions)), - logicalPlan, - on.expr, - matchedActions, - notMatchedActions, - notMatchedBySourceActions, - schemaEvolutionEnabled) - val qe = sparkSession.sessionState.executePlan(merge) - qe.assertCommandExecuted() - } - - private[sql] def withNewMatchedActions(actions: MergeAction*): MergeIntoWriter[T] = { - this.matchedActions ++= actions - this - } - - private[sql] def withNewNotMatchedActions(actions: MergeAction*): MergeIntoWriter[T] = { - this.notMatchedActions ++= actions - this - } - - private[sql] def withNewNotMatchedBySourceActions(actions: MergeAction*): MergeIntoWriter[T] = { - this.notMatchedBySourceActions ++= actions - this - } -} - -/** - * A class for defining actions to be taken when matching rows in a DataFrame during - * a merge operation. - * - * @param mergeIntoWriter The MergeIntoWriter instance responsible for writing data to a - * target DataFrame. - * @param condition An optional condition Expression that specifies when the actions - * should be applied. - * If the condition is None, the actions will be applied to all matched - * rows. - * - * @tparam T The type of data in the MergeIntoWriter. - */ -case class WhenMatched[T] private[sql]( - mergeIntoWriter: MergeIntoWriter[T], - condition: Option[Expression]) { - import mergeIntoWriter.sparkSession.RichColumn - - /** - * Specifies an action to update all matched rows in the DataFrame. - * - * @return The MergeIntoWriter instance with the update all action configured. 
- */ - def updateAll(): MergeIntoWriter[T] = { - mergeIntoWriter.withNewMatchedActions(UpdateStarAction(condition)) - } - - /** - * Specifies an action to update matched rows in the DataFrame with the provided column - * assignments. - * - * @param map A Map of column names to Column expressions representing the updates to be applied. - * @return The MergeIntoWriter instance with the update action configured. - */ - def update(map: Map[String, Column]): MergeIntoWriter[T] = { - mergeIntoWriter.withNewMatchedActions( - UpdateAction(condition, map.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq)) - } - - /** - * Specifies an action to delete matched rows from the DataFrame. - * - * @return The MergeIntoWriter instance with the delete action configured. - */ - def delete(): MergeIntoWriter[T] = { - mergeIntoWriter.withNewMatchedActions(DeleteAction(condition)) - } -} - -/** - * A class for defining actions to be taken when no matching rows are found in a DataFrame - * during a merge operation. - * - * @param MergeIntoWriter The MergeIntoWriter instance responsible for writing data to a - * target DataFrame. - * @param condition An optional condition Expression that specifies when the actions - * defined in this configuration should be applied. - * If the condition is None, the actions will be applied when there - * are no matching rows. - * @tparam T The type of data in the MergeIntoWriter. - */ -case class WhenNotMatched[T] private[sql]( - mergeIntoWriter: MergeIntoWriter[T], - condition: Option[Expression]) { - import mergeIntoWriter.sparkSession.RichColumn - - /** - * Specifies an action to insert all non-matched rows into the DataFrame. - * - * @return The MergeIntoWriter instance with the insert all action configured. - */ - def insertAll(): MergeIntoWriter[T] = { - mergeIntoWriter.withNewNotMatchedActions(InsertStarAction(condition)) - } - - /** - * Specifies an action to insert non-matched rows into the DataFrame with the provided - * column assignments. - * - * @param map A Map of column names to Column expressions representing the values to be inserted. - * @return The MergeIntoWriter instance with the insert action configured. - */ - def insert(map: Map[String, Column]): MergeIntoWriter[T] = { - mergeIntoWriter.withNewNotMatchedActions( - InsertAction(condition, map.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq)) - } -} - - -/** - * A class for defining actions to be performed when there is no match by source - * during a merge operation in a MergeIntoWriter. - * - * @param MergeIntoWriter the MergeIntoWriter instance to which the merge actions will be applied. - * @param condition an optional condition to be used with the merge actions. - * @tparam T the type parameter for the MergeIntoWriter. - */ -case class WhenNotMatchedBySource[T] private[sql]( - mergeIntoWriter: MergeIntoWriter[T], - condition: Option[Expression]) { - import mergeIntoWriter.sparkSession.RichColumn - - /** - * Specifies an action to update all non-matched rows in the target DataFrame when - * not matched by the source. - * - * @return The MergeIntoWriter instance with the update all action configured. - */ - def updateAll(): MergeIntoWriter[T] = { - mergeIntoWriter.withNewNotMatchedBySourceActions(UpdateStarAction(condition)) - } - - /** - * Specifies an action to update non-matched rows in the target DataFrame with the provided - * column assignments when not matched by the source. - * - * @param map A Map of column names to Column expressions representing the updates to be applied. 
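The per-action builders (`WhenMatched`, `WhenNotMatched`, `WhenNotMatchedBySource`) also accept conditions and explicit column assignments. Another hedged sketch under the same assumptions as above (hypothetical `source`/`target` tables, `Dataset.mergeInto` entry point, made-up column names):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

object ConditionalMergeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("conditional-merge").getOrCreate()
    import spark.implicits._

    spark.table("source")
      .mergeInto("target", $"target.id" === $"source.id")
      // Only rewrite rows whose value actually changed, and only selected columns.
      .whenMatched($"target.value" =!= $"source.value")
      .update(Map("value" -> $"source.value", "updated" -> lit(true)))
      // Rows present only in the source are inserted with explicit assignments.
      .whenNotMatched()
      .insert(Map("id" -> $"source.id", "value" -> $"source.value", "updated" -> lit(false)))
      // Rows present only in the target are removed.
      .whenNotMatchedBySource()
      .delete()
      .merge()

    spark.stop()
  }
}
```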
- * @return The MergeIntoWriter instance with the update action configured. - */ - def update(map: Map[String, Column]): MergeIntoWriter[T] = { - mergeIntoWriter.withNewNotMatchedBySourceActions( - UpdateAction(condition, map.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq)) - } - - /** - * Specifies an action to delete non-matched rows from the target DataFrame when not matched by - * the source. - * - * @return The MergeIntoWriter instance with the delete action configured. - */ - def delete(): MergeIntoWriter[T] = { - mergeIntoWriter.withNewNotMatchedBySourceActions(DeleteAction(condition)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 777baa3e62687..bd47a21a1e09b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -22,13 +22,13 @@ import org.apache.spark.annotation.Stable import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias -import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.internal.ExpressionUtils.{column, generateAlias} @@ -53,8 +53,8 @@ class RelationalGroupedDataset protected[sql]( protected[sql] val df: DataFrame, private[sql] val groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) - extends api.RelationalGroupedDataset[Dataset] { - type RGD = RelationalGroupedDataset + extends api.RelationalGroupedDataset { + import RelationalGroupedDataset._ import df.sparkSession._ @@ -117,23 +117,14 @@ class RelationalGroupedDataset protected[sql]( columnExprs.map(column) } - - /** - * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions - * of current `RelationalGroupedDataset`. 
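The `RelationalGroupedDataset.as[K, T]` change seen here still turns an untyped grouping into a typed `KeyValueGroupedDataset`. A small illustrative sketch (column names are made up; the value encoder must match the frame's schema):

```scala
import org.apache.spark.sql.SparkSession

object GroupByAsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("groupby-as").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 1), ("a", 2), ("b", 3)).toDF("k", "v")

    // as[K, T] upgrades the untyped RelationalGroupedDataset into a
    // KeyValueGroupedDataset keyed by the grouping expressions.
    val grouped = df.groupBy($"k").as[String, (String, Int)]

    val summed = grouped.mapGroups((k, rows) => (k, rows.map(_._2).sum))
    summed.show()
    spark.stop()
  }
}
```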
- * - * @since 3.0.0 - */ + /** @inheritdoc */ def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = { - val keyEncoder = encoderFor[K] - val valueEncoder = encoderFor[T] - val (qe, groupingAttributes) = handleGroupingExpression(df.logicalPlan, df.sparkSession, groupingExprs) new KeyValueGroupedDataset( - keyEncoder, - valueEncoder, + implicitly[Encoder[K]], + implicitly[Encoder[T]], qe, df.logicalPlan.output, groupingAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 55f67da68221f..137dbaed9f00a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql import java.net.URI import java.nio.file.Paths import java.util.{ServiceLoader, UUID} +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} +import scala.concurrent.duration.DurationInt import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -57,7 +59,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager -import org.apache.spark.util.{CallSite, SparkFileUtils, Utils} +import org.apache.spark.util.{CallSite, SparkFileUtils, ThreadUtils, Utils} import org.apache.spark.util.ArrayImplicits._ /** @@ -92,8 +94,9 @@ class SparkSession private( @transient private val existingSharedState: Option[SharedState], @transient private val parentSessionState: Option[SessionState], @transient private[sql] val extensions: SparkSessionExtensions, - @transient private[sql] val initialSessionOptions: Map[String, String]) - extends api.SparkSession[Dataset] with Logging { self => + @transient private[sql] val initialSessionOptions: Map[String, String], + @transient private val parentManagedJobTags: Map[String, String]) + extends api.SparkSession with Logging { self => // The call site where this SparkSession was constructed. private val creationSite: CallSite = Utils.getCallSite() @@ -107,7 +110,12 @@ class SparkSession private( private[sql] def this( sc: SparkContext, initialSessionOptions: java.util.HashMap[String, String]) = { - this(sc, None, None, applyAndLoadExtensions(sc), initialSessionOptions.asScala.toMap) + this( + sc, + existingSharedState = None, + parentSessionState = None, + applyAndLoadExtensions(sc), initialSessionOptions.asScala.toMap, + parentManagedJobTags = Map.empty) } private[sql] def this(sc: SparkContext) = this(sc, new java.util.HashMap[String, String]()) @@ -122,6 +130,18 @@ class SparkSession private( .getOrElse(SQLConf.getFallbackConf) }) + /** Tag to mark all jobs owned by this session. */ + private[sql] lazy val sessionJobTag = s"spark-session-$sessionUUID" + + /** + * A map to hold the mapping from user-defined tags to the real tags attached to Jobs. + * Real tag have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`. 
+ */ + @transient + private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = { + new ConcurrentHashMap(parentManagedJobTags.asJava) + } + /** @inheritdoc */ def version: String = SPARK_VERSION @@ -173,16 +193,8 @@ class SparkSession private( @transient val sqlContext: SQLContext = new SQLContext(this) - /** - * Runtime configuration interface for Spark. - * - * This is the interface through which the user can get and set all Spark and Hadoop - * configurations that are relevant to Spark SQL. When getting the value of a config, - * this defaults to the value set in the underlying `SparkContext`, if any. - * - * @since 2.0.0 - */ - @transient lazy val conf: RuntimeConfig = new RuntimeConfig(sessionState.conf) + /** @inheritdoc */ + @transient lazy val conf: RuntimeConfig = new RuntimeConfigImpl(sessionState.conf) /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s @@ -243,7 +255,8 @@ class SparkSession private( Some(sharedState), parentSessionState = None, extensions, - initialSessionOptions) + initialSessionOptions, + parentManagedJobTags = Map.empty) } /** @@ -264,8 +277,10 @@ class SparkSession private( Some(sharedState), Some(sessionState), extensions, - Map.empty) + Map.empty, + managedJobTags.asScala.toMap) result.sessionState // force copy of SessionState + result.managedJobTags // force copy of userDefinedToRealTagsMap result } @@ -480,12 +495,7 @@ class SparkSession private( | Catalog-related methods | * ------------------------- */ - /** - * Interface through which the user may create, drop, alter or query underlying - * databases, tables, functions etc. - * - * @since 2.0.0 - */ + /** @inheritdoc */ @transient lazy val catalog: Catalog = new CatalogImpl(self) /** @inheritdoc */ @@ -649,16 +659,84 @@ class SparkSession private( artifactManager.addLocalArtifacts(uri.flatMap(Artifact.parseArtifacts)) } + /** @inheritdoc */ + override def addTag(tag: String): Unit = { + SparkContext.throwIfInvalidTag(tag) + managedJobTags.put(tag, s"spark-session-$sessionUUID-$tag") + } + + /** @inheritdoc */ + override def removeTag(tag: String): Unit = managedJobTags.remove(tag) + + /** @inheritdoc */ + override def getTags(): Set[String] = managedJobTags.keys().asScala.toSet + + /** @inheritdoc */ + override def clearTags(): Unit = managedJobTags.clear() + /** - * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a - * `DataFrame`. - * {{{ - * sparkSession.read.parquet("/path/to/file.parquet") - * sparkSession.read.schema(schema).json("/path/to/file.json") - * }}} + * Request to interrupt all currently running SQL operations of this session. * - * @since 2.0.0 + * @note Only DataFrame/SQL operations started by this session can be interrupted. + * + * @note This method will wait up to 60 seconds for the interruption request to be issued. + + * @return Sequence of SQL execution IDs requested to be interrupted. + + * @since 4.0.0 */ + override def interruptAll(): Seq[String] = + doInterruptTag(sessionJobTag, "as part of cancellation of all jobs") + + /** + * Request to interrupt all currently running SQL operations of this session with the given + * job tag. + * + * @note Only DataFrame/SQL operations started by this session can be interrupted. + * + * @note This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return Sequence of SQL execution IDs requested to be interrupted. 
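The new session job tags and the `interruptTag`/`interruptAll` methods are easiest to see end to end. A rough, timing-dependent sketch (the tag name and workload are illustrative, and cancellation is best-effort):

```scala
import org.apache.spark.sql.SparkSession

object JobTagInterruptExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("job-tags").getOrCreate()

    // Tag the operations of this session; internally the tag is stored as
    // s"spark-session-$sessionUUID-slow-report", as described above.
    spark.addTag("slow-report")

    val worker = new Thread(() => {
      try {
        spark.range(1000000000L).selectExpr("sum(id)").collect()
      } catch {
        case _: Throwable => () // expected if the job is interrupted below
      }
    })
    worker.start()

    Thread.sleep(2000)
    // Best-effort cancellation of every operation carrying the tag; returns the
    // SQL execution IDs that were asked to stop.
    val interrupted: Seq[String] = spark.interruptTag("slow-report")
    println(s"Interrupted executions: $interrupted")

    worker.join()
    spark.stop()
  }
}
```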
+ + * @since 4.0.0 + */ + override def interruptTag(tag: String): Seq[String] = { + val realTag = managedJobTags.get(tag) + if (realTag == null) return Seq.empty + doInterruptTag(realTag, s"part of cancelled job tags $tag") + } + + private def doInterruptTag(tag: String, reason: String): Seq[String] = { + val cancelledTags = + sparkContext.cancelJobsWithTagWithFuture(tag, reason) + + ThreadUtils.awaitResult(cancelledTags, 60.seconds) + .flatMap(job => Option(job.properties.getProperty(SQLExecution.EXECUTION_ROOT_ID_KEY))) + } + + /** + * Request to interrupt a SQL operation of this session, given its SQL execution ID. + * + * @note Only DataFrame/SQL operations started by this session can be interrupted. + * + * @note This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return The execution ID requested to be interrupted, as a single-element sequence, or an empty + * sequence if the operation is not started by this session. + * + * @since 4.0.0 + */ + override def interruptOperation(operationId: String): Seq[String] = { + scala.util.Try(operationId.toLong).toOption match { + case Some(executionIdToBeCancelled) => + val tagToBeCancelled = SQLExecution.executionIdJobTag(this, executionIdToBeCancelled) + doInterruptTag(tagToBeCancelled, reason = "") + case None => + throw new IllegalArgumentException("executionId must be a number in string form.") + } + } + + /** @inheritdoc */ def read: DataFrameReader = new DataFrameReader(self) /** @@ -744,7 +822,7 @@ class SparkSession private( } /** - * Execute a block of code with the this session set as the active session, and restore the + * Execute a block of code with this session set as the active session, and restore the * previous session on completion. */ private[sql] def withActive[T](block: => T): T = { @@ -759,7 +837,8 @@ class SparkSession private( } private[sql] def leafNodeDefaultParallelism: Int = { - conf.get(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM).getOrElse(sparkContext.defaultParallelism) + sessionState.conf.getConf(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM) + .getOrElse(sparkContext.defaultParallelism) } private[sql] object Converter extends ColumnNodeToExpressionConverter with Serializable { @@ -979,7 +1058,12 @@ object SparkSession extends Logging { loadExtensions(extensions) applyExtensions(sparkContext, extensions) - session = new SparkSession(sparkContext, None, None, extensions, options.toMap) + session = new SparkSession(sparkContext, + existingSharedState = None, + parentSessionState = None, + extensions, + initialSessionOptions = options.toMap, + parentManagedJobTags = Map.empty) setDefaultSession(session) setActiveSession(session) registerContextListener(sparkContext) @@ -1124,13 +1208,13 @@ object SparkSession extends Logging { private[sql] def getOrCloneSessionWithConfigsOff( session: SparkSession, configurations: Seq[ConfigEntry[Boolean]]): SparkSession = { - val configsEnabled = configurations.filter(session.conf.get[Boolean]) + val configsEnabled = configurations.filter(session.sessionState.conf.getConf[Boolean]) if (configsEnabled.isEmpty) { session } else { val newSession = session.cloneSession() configsEnabled.foreach(conf => { - newSession.conf.set(conf, false) + newSession.sessionState.conf.setConf(conf, false) }) newSession } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index c1c9af2ea4273..bc270e6ac64ad 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -29,7 +29,7 @@ import org.apache.spark.internal.LogKeys.CLASS_LOADER import org.apache.spark.security.SocketAuthServer import org.apache.spark.sql.{internal, Column, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -69,7 +69,10 @@ private[sql] object PythonSQLUtils extends Logging { // This is needed when generating SQL documentation for built-in functions. def listBuiltinFunctionInfos(): Array[ExpressionInfo] = { - FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)).toArray + (FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)) ++ + TableFunctionRegistry.functionSet.flatMap( + f => TableFunctionRegistry.builtin.lookupFunction(f))). + groupBy(_.getName).map(v => v._2.head).toArray } private def listAllSQLConfigs(): Seq[(String, String, String, String)] = { @@ -152,6 +155,9 @@ private[sql] object PythonSQLUtils extends Logging { def collect_top_k(e: Column, num: Int, reverse: Boolean): Column = Column.internalFn("collect_top_k", e, lit(num), lit(reverse)) + def binary_search(e: Column, value: Column): Column = + Column.internalFn("array_binary_search", e, value) + def pandasProduct(e: Column, ignoreNA: Boolean): Column = Column.internalFn("pandas_product", e, lit(ignoreNA)) @@ -173,6 +179,13 @@ private[sql] object PythonSQLUtils extends Logging { def pandasCovar(col1: Column, col2: Column, ddof: Int): Column = Column.internalFn("pandas_covar", col1, col2, lit(ddof)) + /** + * A long column that increases one by one. + * This is for 'distributed-sequence' default index in pandas API on Spark. + */ + def distributed_sequence_id(): Column = + Column.internalFn("distributed_sequence_id") + def unresolvedNamedLambdaVariable(name: String): Column = Column(internal.UnresolvedNamedLambdaVariable.apply(name)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala index 89aa7cfd56071..1ee960622fc2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala @@ -64,11 +64,12 @@ class ArtifactManager(session: SparkSession) extends Logging { // The base directory/URI where all artifacts are stored for this `sessionUUID`. protected[artifact] val (artifactPath, artifactURI): (Path, String) = (ArtifactUtils.concatenatePaths(artifactRootPath, session.sessionUUID), - s"$artifactRootURI/${session.sessionUUID}") + s"$artifactRootURI${File.separator}${session.sessionUUID}") // The base directory/URI where all class file artifacts are stored for this `sessionUUID`. 
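Stepping back to the `listBuiltinFunctionInfos` change above: it merges the scalar and table function registries and keeps a single `ExpressionInfo` per name. A tiny standalone sketch of that dedup shape, with a stand-in `Info` type instead of the real `ExpressionInfo`:

```scala
object DedupByNameSketch extends App {
  // Stand-in for ExpressionInfo; only the name matters for the dedup.
  final case class Info(name: String, usage: String)

  val scalarFns = Seq(Info("abs", "abs(expr)"), Info("explode", "explode(expr)"))
  val tableFns  = Seq(Info("explode", "explode(expr) - table form"), Info("range", "range(n)"))

  // Same shape as groupBy(_.getName).map(v => v._2.head): one entry per function name.
  val deduped = (scalarFns ++ tableFns).groupBy(_.name).map(_._2.head).toSeq

  println(deduped.map(_.name).sorted.mkString(", ")) // abs, explode, range
}
```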
protected[artifact] val (classDir, classURI): (Path, String) = - (ArtifactUtils.concatenatePaths(artifactPath, "classes"), s"$artifactURI/classes/") + (ArtifactUtils.concatenatePaths(artifactPath, "classes"), + s"$artifactURI${File.separator}classes${File.separator}") protected[artifact] val state: JobArtifactState = JobArtifactState(session.sessionUUID, Option(classURI)) @@ -112,6 +113,14 @@ class ArtifactManager(session: SparkSession) extends Logging { } } + private def normalizePath(path: Path): Path = { + // Convert the path to a string with the current system's separator + val normalizedPathString = path.toString + .replace('/', File.separatorChar) + .replace('\\', File.separatorChar) + // Convert the normalized string back to a Path object + Paths.get(normalizedPathString).normalize() + } /** * Add and prepare a staged artifact (i.e an artifact that has been rebuilt locally from bytes * over the wire) for use. @@ -128,14 +137,14 @@ class ArtifactManager(session: SparkSession) extends Logging { deleteStagedFile: Boolean = true ): Unit = JobArtifactSet.withActiveJobArtifactState(state) { require(!remoteRelativePath.isAbsolute) - - if (remoteRelativePath.startsWith(s"cache${File.separator}")) { + val normalizedRemoteRelativePath = normalizePath(remoteRelativePath) + if (normalizedRemoteRelativePath.startsWith(s"cache${File.separator}")) { val tmpFile = serverLocalStagingPath.toFile Utils.tryWithSafeFinallyAndFailureCallbacks { val blockManager = session.sparkContext.env.blockManager val blockId = CacheId( sessionUUID = session.sessionUUID, - hash = remoteRelativePath.toString.stripPrefix(s"cache${File.separator}")) + hash = normalizedRemoteRelativePath.toString.stripPrefix(s"cache${File.separator}")) val updater = blockManager.TempFileBasedBlockStoreUpdater( blockId = blockId, level = StorageLevel.MEMORY_AND_DISK_SER, @@ -145,11 +154,11 @@ class ArtifactManager(session: SparkSession) extends Logging { tellMaster = false) updater.save() }(catchBlock = { tmpFile.delete() }) - } else if (remoteRelativePath.startsWith(s"classes${File.separator}")) { + } else if (normalizedRemoteRelativePath.startsWith(s"classes${File.separator}")) { // Move class files to the right directory. val target = ArtifactUtils.concatenatePaths( classDir, - remoteRelativePath.toString.stripPrefix(s"classes${File.separator}")) + normalizedRemoteRelativePath.toString.stripPrefix(s"classes${File.separator}")) // Allow overwriting class files to capture updates to classes. // This is required because the client currently sends all the class files in each class file // transfer. @@ -159,7 +168,7 @@ class ArtifactManager(session: SparkSession) extends Logging { allowOverwrite = true, deleteSource = deleteStagedFile) } else { - val target = ArtifactUtils.concatenatePaths(artifactPath, remoteRelativePath) + val target = ArtifactUtils.concatenatePaths(artifactPath, normalizedRemoteRelativePath) // Disallow overwriting with modified version if (Files.exists(target)) { // makes the query idempotent @@ -167,30 +176,30 @@ class ArtifactManager(session: SparkSession) extends Logging { return } - throw new RuntimeException(s"Duplicate Artifact: $remoteRelativePath. " + + throw new RuntimeException(s"Duplicate Artifact: $normalizedRemoteRelativePath. " + "Artifacts cannot be overwritten.") } transferFile(serverLocalStagingPath, target, deleteSource = deleteStagedFile) // This URI is for Spark file server that starts with "spark://". 
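The `normalizePath` helper above exists because relative artifact paths may arrive from a client with either separator, while the prefix checks use `File.separator`. A standalone sketch of the same idea (the cache path is illustrative):

```scala
import java.io.File
import java.nio.file.{Path, Paths}

object PathNormalizationSketch extends App {
  // Fold both '/' and '\' into the platform separator before prefix checks such as
  // startsWith(s"cache${File.separator}"), mirroring the normalizePath helper above.
  def normalize(path: Path): Path = {
    val normalized = path.toString
      .replace('/', File.separatorChar)
      .replace('\\', File.separatorChar)
    Paths.get(normalized).normalize()
  }

  val windowsStyle = Paths.get("cache\\1234abcd")
  // On a POSIX filesystem the raw path is a single element named "cache\1234abcd",
  // so the prefix check fails until the separators are normalized.
  println(windowsStyle.startsWith("cache"))            // false on Linux/macOS
  println(normalize(windowsStyle).startsWith("cache")) // true on any OS
}
```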
val uri = s"$artifactURI/${Utils.encodeRelativeUnixPathToURIRawPath( - FilenameUtils.separatorsToUnix(remoteRelativePath.toString))}" + FilenameUtils.separatorsToUnix(normalizedRemoteRelativePath.toString))}" - if (remoteRelativePath.startsWith(s"jars${File.separator}")) { + if (normalizedRemoteRelativePath.startsWith(s"jars${File.separator}")) { session.sparkContext.addJar(uri) jarsList.add(target) - } else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) { + } else if (normalizedRemoteRelativePath.startsWith(s"pyfiles${File.separator}")) { session.sparkContext.addFile(uri) - val stringRemotePath = remoteRelativePath.toString + val stringRemotePath = normalizedRemoteRelativePath.toString if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith( ".egg") || stringRemotePath.endsWith(".jar")) { pythonIncludeList.add(target.getFileName.toString) } - } else if (remoteRelativePath.startsWith(s"archives${File.separator}")) { + } else if (normalizedRemoteRelativePath.startsWith(s"archives${File.separator}")) { val canonicalUri = fragment.map(Utils.getUriBuilder(new URI(uri)).fragment).getOrElse(new URI(uri)) session.sparkContext.addArchive(canonicalUri.toString) - } else if (remoteRelativePath.startsWith(s"files${File.separator}")) { + } else if (normalizedRemoteRelativePath.startsWith(s"files${File.separator}")) { session.sparkContext.addFile(uri) } } @@ -301,20 +310,21 @@ class ArtifactManager(session: SparkSession) extends Logging { def uploadArtifactToFs( remoteRelativePath: Path, serverLocalStagingPath: Path): Unit = { + val normalizedRemoteRelativePath = normalizePath(remoteRelativePath) val hadoopConf = session.sparkContext.hadoopConfiguration assert( - remoteRelativePath.startsWith( + normalizedRemoteRelativePath.startsWith( ArtifactManager.forwardToFSPrefix + File.separator)) val destFSPath = new FSPath( Paths - .get("/") - .resolve(remoteRelativePath.subpath(1, remoteRelativePath.getNameCount)) + .get(File.separator) + .resolve(normalizedRemoteRelativePath.subpath(1, normalizedRemoteRelativePath.getNameCount)) .toString) val localPath = serverLocalStagingPath val fs = destFSPath.getFileSystem(hadoopConf) if (fs.isInstanceOf[LocalFileSystem]) { val allowDestLocalConf = - session.conf.get(SQLConf.ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL) + session.sessionState.conf.getConf(SQLConf.ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL) .getOrElse( session.conf.get("spark.connect.copyFromLocalToFs.allowDestLocal").contains("true")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 676be7fe41cbc..c39018ff06fca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -14,657 +14,153 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.sql.catalog -import scala.jdk.CollectionConverters._ +import java.util -import org.apache.spark.annotation.Stable -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} +import org.apache.spark.sql.{api, DataFrame, Dataset} +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.types.StructType -import org.apache.spark.storage.StorageLevel - -/** - * Catalog interface for Spark. To access this, use `SparkSession.catalog`. 
- * - * @since 2.0.0 - */ -@Stable -abstract class Catalog { - - /** - * Returns the current database (namespace) in this session. - * - * @since 2.0.0 - */ - def currentDatabase: String - - /** - * Sets the current database (namespace) in this session. - * - * @since 2.0.0 - */ - def setCurrentDatabase(dbName: String): Unit - - /** - * Returns a list of databases (namespaces) available within the current catalog. - * - * @since 2.0.0 - */ - def listDatabases(): Dataset[Database] - - /** - * Returns a list of databases (namespaces) which name match the specify pattern and - * available within the current catalog. - * - * @since 3.5.0 - */ - def listDatabases(pattern: String): Dataset[Database] - - /** - * Returns a list of tables/views in the current database (namespace). - * This includes all temporary views. - * - * @since 2.0.0 - */ - def listTables(): Dataset[Table] - - /** - * Returns a list of tables/views in the specified database (namespace) (the name can be qualified - * with catalog). - * This includes all temporary views. - * - * @since 2.0.0 - */ - @throws[AnalysisException]("database does not exist") - def listTables(dbName: String): Dataset[Table] - - /** - * Returns a list of tables/views in the specified database (namespace) - * which name match the specify pattern (the name can be qualified with catalog). - * This includes all temporary views. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def listTables(dbName: String, pattern: String): Dataset[Table] - - /** - * Returns a list of functions registered in the current database (namespace). - * This includes all temporary functions. - * - * @since 2.0.0 - */ - def listFunctions(): Dataset[Function] - - /** - * Returns a list of functions registered in the specified database (namespace) (the name can be - * qualified with catalog). - * This includes all built-in and temporary functions. - * - * @since 2.0.0 - */ - @throws[AnalysisException]("database does not exist") - def listFunctions(dbName: String): Dataset[Function] - - /** - * Returns a list of functions registered in the specified database (namespace) - * which name match the specify pattern (the name can be qualified with catalog). - * This includes all built-in and temporary functions. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def listFunctions(dbName: String, pattern: String): Dataset[Function] - /** - * Returns a list of columns for the given table/view or temporary view. - * - * @param tableName is either a qualified or unqualified name that designates a table/view. It - * follows the same resolution rule with SQL: search for temp views first then - * table/views in the current database (namespace). - * @since 2.0.0 - */ - @throws[AnalysisException]("table does not exist") - def listColumns(tableName: String): Dataset[Column] +/** @inheritdoc */ +abstract class Catalog extends api.Catalog { + /** @inheritdoc */ + override def listDatabases(): Dataset[Database] - /** - * Returns a list of columns for the given table/view in the specified database under the Hive - * Metastore. - * - * To list columns for table/view in other catalogs, please use `listColumns(tableName)` with - * qualified table/view name instead. - * - * @param dbName is an unqualified name that designates a database. - * @param tableName is an unqualified name that designates a table/view. 
- * @since 2.0.0 - */ - @throws[AnalysisException]("database or table does not exist") - def listColumns(dbName: String, tableName: String): Dataset[Column] + /** @inheritdoc */ + override def listDatabases(pattern: String): Dataset[Database] - /** - * Get the database (namespace) with the specified name (can be qualified with catalog). This - * throws an AnalysisException when the database (namespace) cannot be found. - * - * @since 2.1.0 - */ - @throws[AnalysisException]("database does not exist") - def getDatabase(dbName: String): Database + /** @inheritdoc */ + override def listTables(): Dataset[Table] - /** - * Get the table or view with the specified name. This table can be a temporary view or a - * table/view. This throws an AnalysisException when no Table can be found. - * - * @param tableName is either a qualified or unqualified name that designates a table/view. It - * follows the same resolution rule with SQL: search for temp views first then - * table/views in the current database (namespace). - * @since 2.1.0 - */ - @throws[AnalysisException]("table does not exist") - def getTable(tableName: String): Table + /** @inheritdoc */ + override def listTables(dbName: String): Dataset[Table] - /** - * Get the table or view with the specified name in the specified database under the Hive - * Metastore. This throws an AnalysisException when no Table can be found. - * - * To get table/view in other catalogs, please use `getTable(tableName)` with qualified table/view - * name instead. - * - * @since 2.1.0 - */ - @throws[AnalysisException]("database or table does not exist") - def getTable(dbName: String, tableName: String): Table + /** @inheritdoc */ + override def listTables(dbName: String, pattern: String): Dataset[Table] - /** - * Get the function with the specified name. This function can be a temporary function or a - * function. This throws an AnalysisException when the function cannot be found. - * - * @param functionName is either a qualified or unqualified name that designates a function. It - * follows the same resolution rule with SQL: search for built-in/temp - * functions first then functions in the current database (namespace). - * @since 2.1.0 - */ - @throws[AnalysisException]("function does not exist") - def getFunction(functionName: String): Function + /** @inheritdoc */ + override def listFunctions(): Dataset[Function] - /** - * Get the function with the specified name in the specified database under the Hive Metastore. - * This throws an AnalysisException when the function cannot be found. - * - * To get functions in other catalogs, please use `getFunction(functionName)` with qualified - * function name instead. - * - * @param dbName is an unqualified name that designates a database. - * @param functionName is an unqualified name that designates a function in the specified database - * @since 2.1.0 - */ - @throws[AnalysisException]("database or function does not exist") - def getFunction(dbName: String, functionName: String): Function + /** @inheritdoc */ + override def listFunctions(dbName: String): Dataset[Function] - /** - * Check if the database (namespace) with the specified name exists (the name can be qualified - * with catalog). - * - * @since 2.1.0 - */ - def databaseExists(dbName: String): Boolean + /** @inheritdoc */ + override def listFunctions(dbName: String, pattern: String): Dataset[Function] - /** - * Check if the table or view with the specified name exists. This can either be a temporary - * view or a table/view. 
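The catalog methods whose inline documentation is being replaced by `@inheritdoc` here are all reachable from `SparkSession.catalog`. A small usage sketch against a temporary view (names are illustrative):

```scala
import org.apache.spark.sql.SparkSession

object CatalogInspectionExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("catalog").getOrCreate()
    import spark.implicits._

    Seq((1, "a"), (2, "b")).toDF("id", "value").createOrReplaceTempView("people")

    // Metadata queries go through spark.catalog.
    println(spark.catalog.currentDatabase)           // default
    spark.catalog.listDatabases().show(false)
    spark.catalog.listTables().show(false)           // includes the temp view "people"
    println(spark.catalog.tableExists("people"))     // true
    spark.catalog.listColumns("people").show(false)

    spark.stop()
  }
}
```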
- * - * @param tableName is either a qualified or unqualified name that designates a table/view. It - * follows the same resolution rule with SQL: search for temp views first then - * table/views in the current database (namespace). - * @since 2.1.0 - */ - def tableExists(tableName: String): Boolean + /** @inheritdoc */ + override def listColumns(tableName: String): Dataset[Column] - /** - * Check if the table or view with the specified name exists in the specified database under the - * Hive Metastore. - * - * To check existence of table/view in other catalogs, please use `tableExists(tableName)` with - * qualified table/view name instead. - * - * @param dbName is an unqualified name that designates a database. - * @param tableName is an unqualified name that designates a table. - * @since 2.1.0 - */ - def tableExists(dbName: String, tableName: String): Boolean + /** @inheritdoc */ + override def listColumns(dbName: String, tableName: String): Dataset[Column] - /** - * Check if the function with the specified name exists. This can either be a temporary function - * or a function. - * - * @param functionName is either a qualified or unqualified name that designates a function. It - * follows the same resolution rule with SQL: search for built-in/temp - * functions first then functions in the current database (namespace). - * @since 2.1.0 - */ - def functionExists(functionName: String): Boolean + /** @inheritdoc */ + override def createTable(tableName: String, path: String): DataFrame - /** - * Check if the function with the specified name exists in the specified database under the - * Hive Metastore. - * - * To check existence of functions in other catalogs, please use `functionExists(functionName)` - * with qualified function name instead. - * - * @param dbName is an unqualified name that designates a database. - * @param functionName is an unqualified name that designates a function. - * @since 2.1.0 - */ - def functionExists(dbName: String, functionName: String): Boolean + /** @inheritdoc */ + override def createTable(tableName: String, path: String, source: String): DataFrame - /** - * Creates a table from the given path and returns the corresponding DataFrame. - * It will use the default data source configured by spark.sql.sources.default. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. - * @since 2.0.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable(tableName: String, path: String): DataFrame = { - createTable(tableName, path) - } - - /** - * Creates a table from the given path and returns the corresponding DataFrame. - * It will use the default data source configured by spark.sql.sources.default. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. - * @since 2.2.0 - */ - def createTable(tableName: String, path: String): DataFrame - - /** - * Creates a table from the given path based on a data source and returns the corresponding - * DataFrame. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. 
- * @since 2.0.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable(tableName: String, path: String, source: String): DataFrame = { - createTable(tableName, path, source) - } - - /** - * Creates a table from the given path based on a data source and returns the corresponding - * DataFrame. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. - * @since 2.2.0 - */ - def createTable(tableName: String, path: String, source: String): DataFrame - - /** - * Creates a table from the given path based on a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. - * @since 2.0.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, options) - } + options: Map[String, String]): DataFrame - /** - * Creates a table based on the dataset in a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. - * @since 2.2.0 - */ - def createTable( + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, options.asScala.toMap) - } + description: String, + options: Map[String, String]): DataFrame - /** - * (Scala-specific) - * Creates a table from the given path based on a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. - * @since 2.0.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - options: Map[String, String]): DataFrame = { - createTable(tableName, source, options) - } + schema: StructType, + options: Map[String, String]): DataFrame - /** - * (Scala-specific) - * Creates a table based on the dataset in a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. - * @since 2.2.0 - */ - def createTable( + /** @inheritdoc */ + override def createTable( tableName: String, source: String, + schema: StructType, + description: String, options: Map[String, String]): DataFrame - /** - * Create a table from the given path based on a data source, a schema and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. 
- * @since 2.0.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( + /** @inheritdoc */ + override def listCatalogs(): Dataset[CatalogMetadata] + + /** @inheritdoc */ + override def listCatalogs(pattern: String): Dataset[CatalogMetadata] + + /** @inheritdoc */ + override def createExternalTable(tableName: String, path: String): DataFrame = + super.createExternalTable(tableName, path) + + /** @inheritdoc */ + override def createExternalTable(tableName: String, path: String, source: String): DataFrame = + super.createExternalTable(tableName, path, source) + + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, schema, options) - } + options: util.Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, options) - /** - * Creates a table based on the dataset in a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. - * @since 3.1.0 - */ - def createTable( + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - description: String, - options: java.util.Map[String, String]): DataFrame = { - createTable( - tableName, - source = source, - description = description, - options = options.asScala.toMap - ) - } + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, options) - /** - * (Scala-specific) - * Creates a table based on the dataset in a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. - * @since 3.1.0 - */ - def createTable( + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, - description: String, - options: Map[String, String]): DataFrame + options: Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, options) - /** - * Create a table based on the dataset in a data source, a schema and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. - * @since 2.2.0 - */ - def createTable( + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, schema, options.asScala.toMap) - } + options: util.Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, schema, options) - /** - * (Scala-specific) - * Create a table from the given path based on a data source, a schema and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. 
- * @since 2.0.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - schema: StructType, - options: Map[String, String]): DataFrame = { - createTable(tableName, source, schema, options) - } + description: String, + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, description, options) - /** - * (Scala-specific) - * Create a table based on the dataset in a data source, a schema and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. - * @since 2.2.0 - */ - def createTable( + /** @inheritdoc */ + override def createTable( tableName: String, source: String, schema: StructType, - options: Map[String, String]): DataFrame + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, schema, options) - /** - * Create a table based on the dataset in a data source, a schema and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. - * @since 3.1.0 - */ - def createTable( + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, schema: StructType, - description: String, - options: java.util.Map[String, String]): DataFrame = { - createTable( - tableName, - source = source, - schema = schema, - description = description, - options = options.asScala.toMap - ) - } + options: Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, schema, options) - /** - * (Scala-specific) - * Create a table based on the dataset in a data source, a schema and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in - * the current database. - * @since 3.1.0 - */ - def createTable( + /** @inheritdoc */ + override def createTable( tableName: String, source: String, schema: StructType, description: String, - options: Map[String, String]): DataFrame - - /** - * Drops the local temporary view with the given view name in the catalog. - * If the view has been cached before, then it will also be uncached. - * - * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that - * created it, i.e. it will be automatically dropped when the session terminates. It's not - * tied to any databases, i.e. we can't use `db1.view1` to reference a local temporary view. - * - * Note that, the return type of this method was Unit in Spark 2.0, but changed to Boolean - * in Spark 2.1. - * - * @param viewName the name of the temporary view to be dropped. - * @return true if the view is dropped successfully, false otherwise. - * @since 2.0.0 - */ - def dropTempView(viewName: String): Boolean - - /** - * Drops the global temporary view with the given view name in the catalog. - * If the view has been cached before, then it will also be uncached. - * - * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application, - * i.e. it will be automatically dropped when the application terminates. 
It's tied to a system - * preserved database `global_temp`, and we must use the qualified name to refer a global temp - * view, e.g. `SELECT * FROM global_temp.view1`. - * - * @param viewName the unqualified name of the temporary view to be dropped. - * @return true if the view is dropped successfully, false otherwise. - * @since 2.1.0 - */ - def dropGlobalTempView(viewName: String): Boolean - - /** - * Recovers all the partitions in the directory of a table and update the catalog. - * Only works with a partitioned table, and not a view. - * - * @param tableName is either a qualified or unqualified name that designates a table. - * If no database identifier is provided, it refers to a table in the - * current database. - * @since 2.1.1 - */ - def recoverPartitions(tableName: String): Unit - - /** - * Returns true if the table is currently cached in-memory. - * - * @param tableName is either a qualified or unqualified name that designates a table/view. - * If no database identifier is provided, it refers to a temporary view or - * a table/view in the current database. - * @since 2.0.0 - */ - def isCached(tableName: String): Boolean - - /** - * Caches the specified table in-memory. - * - * @param tableName is either a qualified or unqualified name that designates a table/view. - * If no database identifier is provided, it refers to a temporary view or - * a table/view in the current database. - * @since 2.0.0 - */ - def cacheTable(tableName: String): Unit - - /** - * Caches the specified table with the given storage level. - * - * @param tableName is either a qualified or unqualified name that designates a table/view. - * If no database identifier is provided, it refers to a temporary view or - * a table/view in the current database. - * @param storageLevel storage level to cache table. - * @since 2.3.0 - */ - def cacheTable(tableName: String, storageLevel: StorageLevel): Unit - - - /** - * Removes the specified table from the in-memory cache. - * - * @param tableName is either a qualified or unqualified name that designates a table/view. - * If no database identifier is provided, it refers to a temporary view or - * a table/view in the current database. - * @since 2.0.0 - */ - def uncacheTable(tableName: String): Unit - - /** - * Removes all cached tables from the in-memory cache. - * - * @since 2.0.0 - */ - def clearCache(): Unit - - /** - * Invalidates and refreshes all the cached data and metadata of the given table. For performance - * reasons, Spark SQL or the external data source library it uses might cache certain metadata - * about a table, such as the location of blocks. When those change outside of Spark SQL, users - * should call this function to invalidate the cache. - * - * If this table is cached as an InMemoryRelation, drop the original cached version and make the - * new version cached lazily. - * - * @param tableName is either a qualified or unqualified name that designates a table/view. - * If no database identifier is provided, it refers to a temporary view or - * a table/view in the current database. - * @since 2.0.0 - */ - def refreshTable(tableName: String): Unit - - /** - * Invalidates and refreshes all the cached data (and the associated metadata) for any `Dataset` - * that contains the given data source path. Path matching is by checking for sub-directories, - * i.e. "/" would invalidate everything that is cached and "/test/parent" would invalidate - * everything that is a subdirectory of "/test/parent". 
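The temporary-view and cache maintenance behaviour documented in this block can be exercised through the public `spark.catalog` API. A minimal sketch, assuming a SparkSession named `spark` and an existing DataFrame `df`; the table name and path are illustrative:

    df.createOrReplaceGlobalTempView("view1")
    spark.sql("SELECT * FROM global_temp.view1").show()
    spark.catalog.dropGlobalTempView("view1")                 // true if the view existed

    spark.catalog.cacheTable("db1.events")
    spark.catalog.refreshByPath("/warehouse/db1.db/events")   // refreshes anything cached under this path
    spark.catalog.uncacheTable("db1.events")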
- * - * @since 2.0.0 - */ - def refreshByPath(path: String): Unit - - /** - * Returns the current catalog in this session. - * - * @since 3.4.0 - */ - def currentCatalog(): String - - /** - * Sets the current catalog in this session. - * - * @since 3.4.0 - */ - def setCurrentCatalog(catalogName: String): Unit - - /** - * Returns a list of catalogs available in this session. - * - * @since 3.4.0 - */ - def listCatalogs(): Dataset[CatalogMetadata] - - /** - * Returns a list of catalogs which name match the specify pattern and available in this session. - * - * @since 3.5.0 - */ - def listCatalogs(pattern: String): Dataset[CatalogMetadata] + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, schema, description, options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala deleted file mode 100644 index 31aee5a43ef47..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalog - -import javax.annotation.Nullable - -import org.apache.spark.annotation.Stable -import org.apache.spark.sql.catalyst.DefinedByConstructorParams - - -// Note: all classes here are expected to be wrapped in Datasets and so must extend -// DefinedByConstructorParams for the catalog to be able to create encoders for them. - -/** - * A catalog in Spark, as returned by the `listCatalogs` method defined in [[Catalog]]. - * - * @param name name of the catalog - * @param description description of the catalog - * @since 3.4.0 - */ -class CatalogMetadata( - val name: String, - @Nullable val description: String) - extends DefinedByConstructorParams { - - override def toString: String = - s"Catalog[name='$name', ${Option(description).map(d => s"description='$d'").getOrElse("")}]" -} - -/** - * A database in Spark, as returned by the `listDatabases` method defined in [[Catalog]]. - * - * @param name name of the database. - * @param catalog name of the catalog that the table belongs to. - * @param description description of the database. - * @param locationUri path (in the form of a uri) to data files. 
- * @since 2.0.0 - */ -@Stable -class Database( - val name: String, - @Nullable val catalog: String, - @Nullable val description: String, - val locationUri: String) - extends DefinedByConstructorParams { - - def this(name: String, description: String, locationUri: String) = { - this(name, null, description, locationUri) - } - - override def toString: String = { - "Database[" + - s"name='$name', " + - Option(catalog).map { c => s"catalog='$c', " }.getOrElse("") + - Option(description).map { d => s"description='$d', " }.getOrElse("") + - s"path='$locationUri']" - } - -} - - -/** - * A table in Spark, as returned by the `listTables` method in [[Catalog]]. - * - * @param name name of the table. - * @param catalog name of the catalog that the table belongs to. - * @param namespace the namespace that the table belongs to. - * @param description description of the table. - * @param tableType type of the table (e.g. view, table). - * @param isTemporary whether the table is a temporary table. - * @since 2.0.0 - */ -@Stable -class Table( - val name: String, - @Nullable val catalog: String, - @Nullable val namespace: Array[String], - @Nullable val description: String, - val tableType: String, - val isTemporary: Boolean) - extends DefinedByConstructorParams { - - if (namespace != null) { - assert(namespace.forall(_ != null)) - } - - def this( - name: String, - database: String, - description: String, - tableType: String, - isTemporary: Boolean) = { - this(name, null, if (database != null) Array(database) else null, - description, tableType, isTemporary) - } - - def database: String = { - if (namespace != null && namespace.length == 1) namespace(0) else null - } - - override def toString: String = { - "Table[" + - s"name='$name', " + - Option(catalog).map { d => s"catalog='$d', " }.getOrElse("") + - Option(database).map { d => s"database='$d', " }.getOrElse("") + - Option(description).map { d => s"description='$d', " }.getOrElse("") + - s"tableType='$tableType', " + - s"isTemporary='$isTemporary']" - } - -} - - -/** - * A column in Spark, as returned by `listColumns` method in [[Catalog]]. - * - * @param name name of the column. - * @param description description of the column. - * @param dataType data type of the column. - * @param nullable whether the column is nullable. - * @param isPartition whether the column is a partition column. - * @param isBucket whether the column is a bucket column. - * @param isCluster whether the column is a clustering column. - * @since 2.0.0 - */ -@Stable -class Column( - val name: String, - @Nullable val description: String, - val dataType: String, - val nullable: Boolean, - val isPartition: Boolean, - val isBucket: Boolean, - val isCluster: Boolean) - extends DefinedByConstructorParams { - - def this( - name: String, - description: String, - dataType: String, - nullable: Boolean, - isPartition: Boolean, - isBucket: Boolean) = { - this(name, description, dataType, nullable, isPartition, isBucket, isCluster = false) - } - - override def toString: String = { - "Column[" + - s"name='$name', " + - Option(description).map { d => s"description='$d', " }.getOrElse("") + - s"dataType='$dataType', " + - s"nullable='$nullable', " + - s"isPartition='$isPartition', " + - s"isBucket='$isBucket', " + - s"isCluster='$isCluster']" - } - -} - - -/** - * A user-defined function in Spark, as returned by `listFunctions` method in [[Catalog]]. - * - * @param name name of the function. - * @param catalog name of the catalog that the table belongs to. 
- * @param namespace the namespace that the table belongs to. - * @param description description of the function; description can be null. - * @param className the fully qualified class name of the function. - * @param isTemporary whether the function is a temporary function or not. - * @since 2.0.0 - */ -@Stable -class Function( - val name: String, - @Nullable val catalog: String, - @Nullable val namespace: Array[String], - @Nullable val description: String, - val className: String, - val isTemporary: Boolean) - extends DefinedByConstructorParams { - - if (namespace != null) { - assert(namespace.forall(_ != null)) - } - - def this( - name: String, - database: String, - description: String, - className: String, - isTemporary: Boolean) = { - this(name, null, if (database != null) Array(database) else null, - description, className, isTemporary) - } - - def database: String = { - if (namespace != null && namespace.length == 1) namespace(0) else null - } - - override def toString: String = { - "Function[" + - s"name='$name', " + - Option(database).map { d => s"database='$d', " }.getOrElse("") + - Option(description).map { d => s"description='$d', " }.getOrElse("") + - s"className='$className', " + - s"isTemporary='$isTemporary']" - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala new file mode 100644 index 0000000000000..c7320d350a7ff --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import scala.jdk.CollectionConverters.IteratorHasAsScala + +import org.apache.spark.SparkException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow} +import org.apache.spark.sql.catalyst.plans.logical.{Call, LocalRelation, LogicalPlan, MultiResult} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure +import org.apache.spark.sql.connector.read.{LocalScan, Scan} +import org.apache.spark.util.ArrayImplicits._ + +class InvokeProcedures(session: SparkSession) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case c: Call if c.resolved && c.bound && c.execute && c.checkArgTypes().isSuccess => + session.sessionState.optimizer.execute(c) match { + case Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args, _) => + invoke(procedure, args) + case _ => + throw SparkException.internalError("Unexpected plan for optimized CALL statement") + } + } + + private def invoke(procedure: BoundProcedure, args: Seq[Expression]): LogicalPlan = { + val input = toInternalRow(args) + val scanIterator = procedure.call(input) + val relations = scanIterator.asScala.map(toRelation).toSeq + relations match { + case Nil => LocalRelation(Nil) + case Seq(relation) => relation + case _ => MultiResult(relations) + } + } + + private def toRelation(scan: Scan): LogicalPlan = scan match { + case s: LocalScan => + val attrs = DataTypeUtils.toAttributes(s.readSchema) + val data = s.rows.toImmutableArraySeq + LocalRelation(attrs, data) + case _ => + throw SparkException.internalError( + s"Only local scans are temporarily supported as procedure output: ${scan.getClass.getName}") + } + + private def toInternalRow(args: Seq[Expression]): InternalRow = { + require(args.forall(_.foldable), "args must be foldable") + val values = args.map(_.eval()).toArray + new GenericInternalRow(values) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index d569f1ed484cc..02ad2e79a5645 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, toPrettySQL, ResolveDefaultColumns => DefaultCols} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, LookupCatalog, SupportsNamespaces, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, DelegatingCatalogExtension, LookupCatalog, SupportsNamespaces, V1Table} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command._ @@ -284,10 +284,20 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) case AnalyzeColumn(ResolvedV1TableOrViewIdentifier(ident), columnNames, allColumns) => AnalyzeColumnCommand(ident, 
columnNames, allColumns) - case RepairTable(ResolvedV1TableIdentifier(ident), addPartitions, dropPartitions) => + // V2 catalog doesn't support REPAIR TABLE yet, we must use v1 command here. + case RepairTable( + ResolvedV1TableIdentifierInSessionCatalog(ident), + addPartitions, + dropPartitions) => RepairTableCommand(ident, addPartitions, dropPartitions) - case LoadData(ResolvedV1TableIdentifier(ident), path, isLocal, isOverwrite, partition) => + // V2 catalog doesn't support LOAD DATA yet, we must use v1 command here. + case LoadData( + ResolvedV1TableIdentifierInSessionCatalog(ident), + path, + isLocal, + isOverwrite, + partition) => LoadDataCommand( ident, path, @@ -336,7 +346,8 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) } ShowColumnsCommand(db, v1TableName, output) - case RecoverPartitions(ResolvedV1TableIdentifier(ident)) => + // V2 catalog doesn't support RECOVER PARTITIONS yet, we must use v1 command here. + case RecoverPartitions(ResolvedV1TableIdentifierInSessionCatalog(ident)) => RepairTableCommand( ident, enableAddPartitions = true, @@ -364,8 +375,9 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) purge, retainData = false) + // V2 catalog doesn't support setting serde properties yet, we must use v1 command here. case SetTableSerDeProperties( - ResolvedV1TableIdentifier(ident), + ResolvedV1TableIdentifierInSessionCatalog(ident), serdeClassName, serdeProperties, partitionSpec) => @@ -380,10 +392,10 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) // V2 catalog doesn't support setting partition location yet, we must use v1 command here. case SetTableLocation( - ResolvedTable(catalog, _, t: V1Table, _), + ResolvedV1TableIdentifierInSessionCatalog(ident), Some(partitionSpec), - location) if isSessionCatalog(catalog) => - AlterTableSetLocationCommand(t.v1Table.identifier, Some(partitionSpec), location) + location) => + AlterTableSetLocationCommand(ident, Some(partitionSpec), location) case AlterViewAs(ResolvedViewIdentifier(ident), originalText, query) => AlterViewAsCommand(ident, originalText, query) @@ -600,6 +612,14 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) } } + object ResolvedV1TableIdentifierInSessionCatalog { + def unapply(resolved: LogicalPlan): Option[TableIdentifier] = resolved match { + case ResolvedTable(catalog, _, t: V1Table, _) if isSessionCatalog(catalog) => + Some(t.catalogTable.identifier) + case _ => None + } + } + object ResolvedV1TableOrViewIdentifier { def unapply(resolved: LogicalPlan): Option[TableIdentifier] = resolved match { case ResolvedV1TableIdentifier(ident) => Some(ident) @@ -684,7 +704,8 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) } private def supportsV1Command(catalog: CatalogPlugin): Boolean = { - isSessionCatalog(catalog) && - SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isEmpty + isSessionCatalog(catalog) && ( + SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isEmpty || + catalog.isInstanceOf[DelegatingCatalogExtension]) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTranspose.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTranspose.scala new file mode 100644 index 0000000000000..d71237ca15ec3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTranspose.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, IsNotNull, Literal, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Limit, LogicalPlan, Project, Sort, Transpose} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{AtomicType, DataType, StringType} +import org.apache.spark.unsafe.types.UTF8String + + +/** + * Rule that resolves and transforms an `UnresolvedTranspose` logical plan into a `Transpose` + * logical plan, which effectively transposes a DataFrame by turning rows into columns based + * on a specified index column. + * + * The high-level logic for the transpose operation is as follows: + * - If the index column is not provided, the first column of the DataFrame is used as the + * default index column. + * - The index column is cast to `StringType` to ensure consistent column naming. + * - Non-index columns are cast to a common data type, determined by finding the least + * common type that can accommodate all non-index columns. + * - The data is sorted by the index column, and rows with `null` index values are excluded + * from the transpose operation. + * - The transposed DataFrame is constructed by turning the original rows into columns, with + * the index column values becoming the new column names and the non-index column values + * populating the transposed data. 
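To make the high-level logic above concrete, a short usage sketch follows. It assumes a SparkSession named `spark` with `spark.implicits._` in scope, and that the Dataset API exposes a `transpose` method producing the `UnresolvedTranspose` node this rule resolves; the output mirrors the worked example in the rule's own comments.

    import org.apache.spark.sql.functions.col

    val df = Seq((1, 10, 20), (2, 30, 40)).toDF("id", "col1", "col2")
    // "id" is the index column: its values (cast to string) become the new
    // column names, and the remaining column names fill the generated "key" column.
    df.transpose(col("id")).show()
    // +----+---+---+
    // | key|  1|  2|
    // +----+---+---+
    // |col1| 10| 30|
    // |col2| 20| 40|
    // +----+---+---+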
+ */ +class ResolveTranspose(sparkSession: SparkSession) extends Rule[LogicalPlan] { + + private def leastCommonType(dataTypes: Seq[DataType]): DataType = { + if (dataTypes.isEmpty) { + StringType + } else { + dataTypes.reduce { (dt1, dt2) => + AnsiTypeCoercion.findTightestCommonType(dt1, dt2).getOrElse { + throw new AnalysisException( + errorClass = "TRANSPOSE_NO_LEAST_COMMON_TYPE", + messageParameters = Map( + "dt1" -> dt1.sql, + "dt2" -> dt2.sql) + ) + } + } + } + } + + private def transposeMatrix( + fullCollectedRows: Array[InternalRow], + nonIndexColumnNames: Seq[String], + nonIndexColumnDataTypes: Seq[DataType]): Array[Array[Any]] = { + val numTransposedRows = fullCollectedRows.head.numFields - 1 + val numTransposedCols = fullCollectedRows.length + 1 + val finalMatrix = Array.ofDim[Any](numTransposedRows, numTransposedCols) + + // Example of the original DataFrame: + // +---+-----+-----+ + // | id|col1 |col2 | + // +---+-----+-----+ + // | 1| 10 | 20 | + // | 2| 30 | 40 | + // +---+-----+-----+ + // + // After transposition, the finalMatrix will look like: + // [ + // ["col1", 10, 30], // Transposed row for col1 + // ["col2", 20, 40] // Transposed row for col2 + // ] + + for (i <- 0 until numTransposedRows) { + // Insert non-index column name as the first element in each transposed row + finalMatrix(i)(0) = UTF8String.fromString(nonIndexColumnNames(i)) + + for (j <- 1 until numTransposedCols) { + // Insert the transposed data + + // Example: If j = 2, then row = fullCollectedRows(1) + // This corresponds to the second row of the original DataFrame: InternalRow(2, 30, 40) + val row = fullCollectedRows(j - 1) + + // Example: If i = 0 (for "col1"), and j = 2, + // then finalMatrix(0)(2) corresponds to row.get(1, nonIndexColumnDataTypes(0)), + // which accesses the value 30 from InternalRow(2, 30, 40) + finalMatrix(i)(j) = row.get(i + 1, nonIndexColumnDataTypes(i)) + } + } + finalMatrix + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( + _.containsPattern(TreePattern.UNRESOLVED_TRANSPOSE)) { + case t @ UnresolvedTranspose(indices, child) if child.resolved && indices.forall(_.resolved) => + assert(indices.length == 0 || indices.length == 1, + "The number of index columns should be either 0 or 1.") + + // Handle empty frame with no column headers + if (child.output.isEmpty) { + return Transpose(Seq.empty) + } + + // Use the first column as index column if not provided + val inferredIndexColumn = if (indices.isEmpty) { + child.output.head + } else { + indices.head + } + + // Cast the index column to StringType + val indexColumnAsString = inferredIndexColumn match { + case attr: Attribute if attr.dataType.isInstanceOf[AtomicType] => + Alias(Cast(attr, StringType), attr.name)() + case attr: Attribute => + throw new AnalysisException( + errorClass = "TRANSPOSE_INVALID_INDEX_COLUMN", + messageParameters = Map( + "reason" -> s"Index column must be of atomic type, but found: ${attr.dataType}") + ) + case _ => + throw new AnalysisException( + errorClass = "TRANSPOSE_INVALID_INDEX_COLUMN", + messageParameters = Map( + "reason" -> s"Index column must be an atomic attribute") + ) + } + + // Cast non-index columns to the least common type + val nonIndexColumns = child.output.filterNot( + _.exprId == inferredIndexColumn.asInstanceOf[Attribute].exprId) + val nonIndexTypes = nonIndexColumns.map(_.dataType) + val commonType = leastCommonType(nonIndexTypes) + val nonIndexColumnsAsLCT = nonIndexColumns.map { attr => + Alias(Cast(attr, commonType), attr.name)() + 
} + + // Exclude nulls and sort index column values, and collect the casted frame + val allCastCols = indexColumnAsString +: nonIndexColumnsAsLCT + val nonNullChild = Filter(IsNotNull(inferredIndexColumn), child) + val sortedChild = Sort( + Seq(SortOrder(inferredIndexColumn, Ascending)), + global = true, + nonNullChild + ) + val projectAllCastCols = Project(allCastCols, sortedChild) + val maxValues = sparkSession.sessionState.conf.dataFrameTransposeMaxValues + val limit = Literal(maxValues + 1) + val limitedProject = Limit(limit, projectAllCastCols) + val queryExecution = sparkSession.sessionState.executePlan(limitedProject) + val fullCollectedRows = queryExecution.executedPlan.executeCollect() + + if (fullCollectedRows.isEmpty) { + // Return a DataFrame with a single column "key" containing non-index column names + val keyAttr = AttributeReference("key", StringType, nullable = false)() + val keyValues = nonIndexColumns.map(col => UTF8String.fromString(col.name)) + val keyRows = keyValues.map(value => InternalRow(value)) + + Transpose(Seq(keyAttr), keyRows) + } else { + if (fullCollectedRows.length > maxValues) { + throw new AnalysisException( + errorClass = "TRANSPOSE_EXCEED_ROW_LIMIT", + messageParameters = Map( + "maxValues" -> maxValues.toString, + "config" -> SQLConf.DATAFRAME_TRANSPOSE_MAX_VALUES.key)) + } + + // Transpose the matrix + val nonIndexColumnNames = nonIndexColumns.map(_.name) + val nonIndexColumnDataTypes = projectAllCastCols.output.tail.map(attr => attr.dataType) + val transposedMatrix = transposeMatrix( + fullCollectedRows, nonIndexColumnNames, nonIndexColumnDataTypes) + val transposedInternalRows = transposedMatrix.map { row => + InternalRow.fromSeq(row.toIndexedSeq) + } + + // Construct output attributes + val keyAttr = AttributeReference("key", StringType, nullable = false)() + val transposedColumnNames = fullCollectedRows.map { row => row.getString(0) } + val valueAttrs = transposedColumnNames.map { value => + AttributeReference( + value, + commonType + )() + } + + val transposeOutput = (keyAttr +: valueAttrs).toIndexedSeq + val transposeData = transposedInternalRows.toIndexedSeq + Transpose(transposeOutput, transposeData) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala new file mode 100644 index 0000000000000..af91b57a6848b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.classic + +import scala.language.implicitConversions + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql._ + +/** + * Conversions from sql interfaces to the Classic specific implementation. + * + * This class is mainly used by the implementation, but is also meant to be used by extension + * developers. + * + * We provide both a trait and an object. The trait is useful in situations where an extension + * developer needs to use these conversions in a project covering multiple Spark versions. They can + * create a shim for these conversions, the Spark 4+ version of the shim implements this trait, and + * shims for older versions do not. + */ +@DeveloperApi +trait ClassicConversions { + implicit def castToImpl(session: api.SparkSession): SparkSession = + session.asInstanceOf[SparkSession] + + implicit def castToImpl[T](ds: api.Dataset[T]): Dataset[T] = + ds.asInstanceOf[Dataset[T]] + + implicit def castToImpl(rgds: api.RelationalGroupedDataset): RelationalGroupedDataset = + rgds.asInstanceOf[RelationalGroupedDataset] + + implicit def castToImpl[K, V](kvds: api.KeyValueGroupedDataset[K, V]) + : KeyValueGroupedDataset[K, V] = kvds.asInstanceOf[KeyValueGroupedDataset[K, V]] +} + +object ClassicConversions extends ClassicConversions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index aae424afcb8ac..1bf6f4e4d7d9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -474,7 +474,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { // Bucketed scan only has one time overhead but can have multi-times benefits in cache, // so we always do bucketed scan in a cached plan. 
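A hedged sketch of how an extension developer might rely on the `ClassicConversions` helpers introduced above; the function is hypothetical and only demonstrates the implicit widening from the `api` interfaces to the classic implementations.

    import org.apache.spark.sql.{api, Dataset}
    import org.apache.spark.sql.classic.ClassicConversions._

    // castToImpl applies implicitly, so interface-typed values can reach
    // classic-only internals such as queryExecution.
    def analyzedPlanString[T](ds: api.Dataset[T]): String = {
      val impl: Dataset[T] = ds   // implicit api.Dataset[T] => Dataset[T]
      impl.queryExecution.analyzed.treeString
    }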
var disableConfigs = Seq(SQLConf.AUTO_BUCKETED_SCAN_ENABLED) - if (!session.conf.get(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING)) { + if (!session.sessionState.conf.getConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING)) { // Allowing changing cached plan output partitioning might lead to regression as it introduces // extra shuffle disableConfigs = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala index 8a544de7567e8..a0c3d7b51c2c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala @@ -71,13 +71,15 @@ case class EmptyRelationExec(@transient logical: LogicalPlan) extends LeafExecNo maxFields, printNodeId, indent) - lastChildren.add(true) - logical.generateTreeString( - depth + 1, lastChildren, append, verbose, "", false, maxFields, printNodeId, indent) - lastChildren.remove(lastChildren.size() - 1) + Option(logical).foreach { _ => + lastChildren.add(true) + logical.generateTreeString( + depth + 1, lastChildren, append, verbose, "", false, maxFields, printNodeId, indent) + lastChildren.remove(lastChildren.size() - 1) + } } override def doCanonicalize(): SparkPlan = { - this.copy(logical = LocalRelation(logical.output).canonicalized) + this.copy(logical = LocalRelation(output).canonicalized) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala new file mode 100644 index 0000000000000..c2b12b053c927 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class MultiResultExec(children: Seq[SparkPlan]) extends SparkPlan { + + override def output: Seq[Attribute] = children.lastOption.map(_.output).getOrElse(Nil) + + override protected def doExecute(): RDD[InternalRow] = { + children.lastOption.map(_.execute()).getOrElse(sparkContext.emptyRDD) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[SparkPlan]): MultiResultExec = { + copy(children = newChildren) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 58fff2d4a1a29..5db14a8662138 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -44,7 +44,7 @@ object SQLExecution extends Logging { private def nextExecutionId: Long = _nextExecutionId.getAndIncrement - private val executionIdToQueryExecution = new ConcurrentHashMap[Long, QueryExecution]() + private[sql] val executionIdToQueryExecution = new ConcurrentHashMap[Long, QueryExecution]() def getQueryExecution(executionId: Long): QueryExecution = { executionIdToQueryExecution.get(executionId) @@ -52,6 +52,9 @@ object SQLExecution extends Logging { private val testing = sys.props.contains(IS_TESTING.key) + private[sql] def executionIdJobTag(session: SparkSession, id: Long) = + s"${session.sessionJobTag}-execution-root-id-$id" + private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = { val sc = sparkSession.sparkContext // only throw an exception during tests. a missing execution ID should not fail a job. @@ -82,12 +85,13 @@ object SQLExecution extends Logging { // And for the root execution, rootExecutionId == executionId. 
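The per-execution job tag that `SQLExecution` now attaches exists so that everything submitted under a root execution can be cancelled with the tag-based SparkContext API. A rough sketch, usable only from within the `sql` package since `executionIdJobTag` and `sessionJobTag` are internal; `session` and `executionId` are assumed to be in scope.

    // Cancel every job spawned by one root execution of this session.
    val tag = SQLExecution.executionIdJobTag(session, executionId)
    session.sparkContext.cancelJobsWithTag(tag)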
if (sc.getLocalProperty(EXECUTION_ROOT_ID_KEY) == null) { sc.setLocalProperty(EXECUTION_ROOT_ID_KEY, executionId.toString) + sc.addJobTag(executionIdJobTag(sparkSession, executionId)) } val rootExecutionId = sc.getLocalProperty(EXECUTION_ROOT_ID_KEY).toLong executionIdToQueryExecution.put(executionId, queryExecution) val originalInterruptOnCancel = sc.getLocalProperty(SPARK_JOB_INTERRUPT_ON_CANCEL) if (originalInterruptOnCancel == null) { - val interruptOnCancel = sparkSession.conf.get(SQLConf.INTERRUPT_ON_CANCEL) + val interruptOnCancel = sparkSession.sessionState.conf.getConf(SQLConf.INTERRUPT_ON_CANCEL) sc.setInterruptOnCancel(interruptOnCancel) } try { @@ -116,92 +120,94 @@ object SQLExecution extends Logging { val redactedConfigs = sparkSession.sessionState.conf.redactOptions(modifiedConfigs) withSQLConfPropagated(sparkSession) { - var ex: Option[Throwable] = None - var isExecutedPlanAvailable = false - val startTime = System.nanoTime() - val startEvent = SparkListenerSQLExecutionStart( - executionId = executionId, - rootExecutionId = Some(rootExecutionId), - description = desc, - details = callSite.longForm, - physicalPlanDescription = "", - sparkPlanInfo = SparkPlanInfo.EMPTY, - time = System.currentTimeMillis(), - modifiedConfigs = redactedConfigs, - jobTags = sc.getJobTags(), - jobGroupId = Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)) - ) - try { - body match { - case Left(e) => - sc.listenerBus.post(startEvent) + withSessionTagsApplied(sparkSession) { + var ex: Option[Throwable] = None + var isExecutedPlanAvailable = false + val startTime = System.nanoTime() + val startEvent = SparkListenerSQLExecutionStart( + executionId = executionId, + rootExecutionId = Some(rootExecutionId), + description = desc, + details = callSite.longForm, + physicalPlanDescription = "", + sparkPlanInfo = SparkPlanInfo.EMPTY, + time = System.currentTimeMillis(), + modifiedConfigs = redactedConfigs, + jobTags = sc.getJobTags(), + jobGroupId = Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)) + ) + try { + body match { + case Left(e) => + sc.listenerBus.post(startEvent) + throw e + case Right(f) => + val planDescriptionMode = + ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode) + val planDesc = queryExecution.explainString(planDescriptionMode) + val planInfo = try { + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan) + } catch { + case NonFatal(e) => + logDebug("Failed to generate SparkPlanInfo", e) + // If the queryExecution already failed before this, we are not able to generate + // the the plan info, so we use and empty graphviz node to make the UI happy + SparkPlanInfo.EMPTY + } + sc.listenerBus.post( + startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo)) + isExecutedPlanAvailable = true + f() + } + } catch { + case e: Throwable => + ex = Some(e) throw e - case Right(f) => - val planDescriptionMode = - ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode) - val planDesc = queryExecution.explainString(planDescriptionMode) - val planInfo = try { - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan) - } catch { - case NonFatal(e) => - logDebug("Failed to generate SparkPlanInfo", e) - // If the queryExecution already failed before this, we are not able to generate - // the the plan info, so we use and empty graphviz node to make the UI happy - SparkPlanInfo.EMPTY - } - sc.listenerBus.post( - startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo)) - isExecutedPlanAvailable = true - f() - 
} - } catch { - case e: Throwable => - ex = Some(e) - throw e - } finally { - val endTime = System.nanoTime() - val errorMessage = ex.map { - case e: SparkThrowable => - SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY) - case e => - Utils.exceptionString(e) - } - if (queryExecution.shuffleCleanupMode != DoNotCleanup - && isExecutedPlanAvailable) { - val shuffleIds = queryExecution.executedPlan match { - case ae: AdaptiveSparkPlanExec => - ae.context.shuffleIds.asScala.keys - case _ => - Iterable.empty + } finally { + val endTime = System.nanoTime() + val errorMessage = ex.map { + case e: SparkThrowable => + SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY) + case e => + Utils.exceptionString(e) } - shuffleIds.foreach { shuffleId => - queryExecution.shuffleCleanupMode match { - case RemoveShuffleFiles => - // Same as what we do in ContextCleaner.doCleanupShuffle, but do not unregister - // the shuffle on MapOutputTracker, so that stage retries would be triggered. - // Set blocking to Utils.isTesting to deflake unit tests. - sc.shuffleDriverComponents.removeShuffle(shuffleId, Utils.isTesting) - case SkipMigration => - SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId) - case _ => // this should not happen + if (queryExecution.shuffleCleanupMode != DoNotCleanup + && isExecutedPlanAvailable) { + val shuffleIds = queryExecution.executedPlan match { + case ae: AdaptiveSparkPlanExec => + ae.context.shuffleIds.asScala.keys + case _ => + Iterable.empty + } + shuffleIds.foreach { shuffleId => + queryExecution.shuffleCleanupMode match { + case RemoveShuffleFiles => + // Same as what we do in ContextCleaner.doCleanupShuffle, but do not unregister + // the shuffle on MapOutputTracker, so that stage retries would be triggered. + // Set blocking to Utils.isTesting to deflake unit tests. + sc.shuffleDriverComponents.removeShuffle(shuffleId, Utils.isTesting) + case SkipMigration => + SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId) + case _ => // this should not happen + } } } + val event = SparkListenerSQLExecutionEnd( + executionId, + System.currentTimeMillis(), + // Use empty string to indicate no error, as None may mean events generated by old + // versions of Spark. + errorMessage.orElse(Some(""))) + // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the + // `name` parameter. The `ExecutionListenerManager` only watches SQL executions with + // name. We can specify the execution name in more places in the future, so that + // `QueryExecutionListener` can track more cases. + event.executionName = name + event.duration = endTime - startTime + event.qe = queryExecution + event.executionFailure = ex + sc.listenerBus.post(event) } - val event = SparkListenerSQLExecutionEnd( - executionId, - System.currentTimeMillis(), - // Use empty string to indicate no error, as None may mean events generated by old - // versions of Spark. - errorMessage.orElse(Some(""))) - // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the `name` - // parameter. The `ExecutionListenerManager` only watches SQL executions with name. We - // can specify the execution name in more places in the future, so that - // `QueryExecutionListener` can track more cases. 
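For context on the `ExecutionListenerManager` comment above: named executions are observable through the public listener API. A small sketch; the listener body is illustrative.

    import org.apache.spark.sql.execution.QueryExecution
    import org.apache.spark.sql.util.QueryExecutionListener

    spark.listenerManager.register(new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
        println(s"$funcName succeeded in ${durationNs / 1000000} ms")
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
        println(s"$funcName failed: ${exception.getMessage}")
    })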
- event.executionName = name - event.duration = endTime - startTime - event.qe = queryExecution - event.executionFailure = ex - sc.listenerBus.post(event) } } } finally { @@ -211,6 +217,7 @@ object SQLExecution extends Logging { // The current execution is the root execution if rootExecutionId == executionId. if (sc.getLocalProperty(EXECUTION_ROOT_ID_KEY) == executionId.toString) { sc.setLocalProperty(EXECUTION_ROOT_ID_KEY, null) + sc.removeJobTag(executionIdJobTag(sparkSession, executionId)) } sc.setLocalProperty(SPARK_JOB_INTERRUPT_ON_CANCEL, originalInterruptOnCancel) } @@ -238,15 +245,28 @@ object SQLExecution extends Logging { val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) withSQLConfPropagated(sparkSession) { - try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) - body - } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + withSessionTagsApplied(sparkSession) { + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } } } } + private[sql] def withSessionTagsApplied[T](sparkSession: SparkSession)(block: => T): T = { + val allTags = sparkSession.managedJobTags.values().asScala.toSet + sparkSession.sessionJobTag + sparkSession.sparkContext.addJobTags(allTags) + + try { + block + } finally { + sparkSession.sparkContext.removeJobTags(allTags) + } + } + /** * Wrap an action with specified SQL configs. These configs will be propagated to the executor * side via job local properties. @@ -286,10 +306,13 @@ object SQLExecution extends Logging { val originalSession = SparkSession.getActiveSession val originalLocalProps = sc.getLocalProperties SparkSession.setActiveSession(activeSession) - sc.setLocalProperties(localProps) - val res = body - // reset active session and local props. - sc.setLocalProperties(originalLocalProps) + val res = withSessionTagsApplied(activeSession) { + sc.setLocalProperties(localProps) + val res = body + // reset active session and local props. 
+ sc.setLocalProperties(originalLocalProps) + res + } if (originalSession.nonEmpty) { SparkSession.setActiveSession(originalSession.get) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 8f27a0e8f673d..a8261e5d98ba0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER import org.apache.spark.sql.catalyst.util.DateTimeConstants import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors} import org.apache.spark.sql.execution.command._ @@ -466,22 +465,6 @@ class SparkSqlAstBuilder extends AstBuilder { } } - - private def checkInvalidParameter(plan: LogicalPlan, statement: String): - Unit = { - plan.foreach { p => - p.expressions.foreach { expr => - if (expr.containsPattern(PARAMETER)) { - throw QueryParsingErrors.parameterMarkerNotAllowed(statement, p.origin) - } - } - } - plan.children.foreach(p => checkInvalidParameter(p, statement)) - plan.innerChildren.collect { - case child: LogicalPlan => checkInvalidParameter(child, statement) - } - } - /** * Create or replace a view. This creates a [[CreateViewCommand]]. * @@ -537,12 +520,13 @@ class SparkSqlAstBuilder extends AstBuilder { } val qPlan: LogicalPlan = plan(ctx.query) - // Disallow parameter markers in the body of the view. + // Disallow parameter markers in the query of the view. // We need this limitation because we store the original query text, pre substitution. - // To lift this we would need to reconstitute the body with parameter markers replaced with the + // To lift this we would need to reconstitute the query with parameter markers replaced with the // values given at CREATE VIEW time, or we would need to store the parameter values alongside // the text. - checkInvalidParameter(qPlan, "CREATE VIEW body") + // The same rule can be found in CACHE TABLE builder. 
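A quick illustration of the restriction the comment above describes, using the parameterized `sql` entry point for named parameter markers; the exact error raised for the second statement is omitted here.

    // Parameter markers are fine in ordinary queries ...
    spark.sql("SELECT :x AS c", Map("x" -> 1)).show()

    // ... but not inside the query of CREATE VIEW, because the original query
    // text is stored verbatim; this statement is rejected at parse/analysis time.
    spark.sql("CREATE VIEW v AS SELECT :x AS c", Map("x" -> 1))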
+ checkInvalidParameter(qPlan, "the query of CREATE VIEW") if (viewType == PersistedView) { val originalText = source(ctx.query) assert(Option(originalText).isDefined, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6d940a30619fb..aee735e48fc5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -1041,6 +1041,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case WriteFiles(child, fileFormat, partitionColumns, bucket, options, staticPartitions) => WriteFilesExec(planLater(child), fileFormat, partitionColumns, bucket, options, staticPartitions) :: Nil + case MultiResult(children) => + MultiResultExec(children.map(planLater)) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 3832d73044078..09d9915022a65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -504,7 +504,7 @@ case class ScalaAggregator[IN, BUF, OUT]( private[this] lazy val inputDeserializer = inputEncoder.createDeserializer() private[this] lazy val bufferSerializer = bufferEncoder.createSerializer() private[this] lazy val bufferDeserializer = bufferEncoder.createDeserializer() - private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]] + private[this] lazy val outputEncoder = encoderFor(agg.outputEncoder) private[this] lazy val outputSerializer = outputEncoder.createSerializer() def dataType: DataType = outputEncoder.objSerializer.dataType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 5a9adf8ab553d..91454c79df600 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -330,7 +330,7 @@ object CommandUtils extends Logging { val attributePercentiles = mutable.HashMap[Attribute, ArrayData]() if (attrsToGenHistogram.nonEmpty) { val percentiles = (0 to conf.histogramNumBins) - .map(i => i.toDouble / conf.histogramNumBins).toArray + .map(i => i.toDouble / conf.histogramNumBins).toArray[Any] val namedExprs = attrsToGenHistogram.map { attr => val aggFunc = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index ea2736b2c1266..ea9d53190546e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, SupervisingCommand} +import org.apache.spark.sql.catalyst.plans.logical.{Command, ExecutableDuringAnalysis, LogicalPlan, 
SupervisingCommand} import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike} import org.apache.spark.sql.connector.ExternalCommandRunner import org.apache.spark.sql.execution.{CommandExecutionMode, ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode} @@ -165,14 +165,19 @@ case class ExplainCommand( // Run through the optimizer to generate the physical plan. override def run(sparkSession: SparkSession): Seq[Row] = try { - val outputString = sparkSession.sessionState.executePlan(logicalPlan, CommandExecutionMode.SKIP) - .explainString(mode) + val stagedLogicalPlan = stageForAnalysis(logicalPlan) + val qe = sparkSession.sessionState.executePlan(stagedLogicalPlan, CommandExecutionMode.SKIP) + val outputString = qe.explainString(mode) Seq(Row(outputString)) } catch { case NonFatal(cause) => ("Error occurred during query planning: \n" + cause.getMessage).split("\n") .map(Row(_)).toImmutableArraySeq } + private def stageForAnalysis(plan: LogicalPlan): LogicalPlan = plan transform { + case p: ExecutableDuringAnalysis => p.stageForExplain() + } + def withTransformedSupervisedPlan(transformer: LogicalPlan => LogicalPlan): LogicalPlan = copy(logicalPlan = transformer(logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 3f221bfa53051..814e56b204f9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -861,7 +861,7 @@ case class RepairTableCommand( // Hive metastore may not have enough memory to handle millions of partitions in single RPC, // we should split them into smaller batches. Since Hive client is not thread safe, we cannot // do this in parallel. 
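The batching described in the comment above is a plain `Iterator.grouped` pattern. A tiny sketch with stand-in data; in the real command the batch size comes from `SQLConf.ADD_PARTITION_BATCH_SIZE` and each group is sent to the metastore as one RPC:

```scala
// Stand-in partition specs; RepairTableCommand derives these from the discovered partitions.
val partitionSpecs = (1 to 10).map(i => s"dt=2024-01-$i")
val batchSize = 3 // illustrative; the real value is read from the session conf

partitionSpecs.iterator.grouped(batchSize).foreach { batch =>
  // Each group would be handed to the metastore client in one call.
  println(s"batch of ${batch.size}: ${batch.mkString(", ")}")
}
```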
- val batchSize = spark.conf.get(SQLConf.ADD_PARTITION_BATCH_SIZE) + val batchSize = spark.sessionState.conf.getConf(SQLConf.ADD_PARTITION_BATCH_SIZE) partitionSpecsAndLocs.iterator.grouped(batchSize).foreach { batch => val now = MILLISECONDS.toSeconds(System.currentTimeMillis()) val parts = batch.map { case (spec, location) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index e1061a46db7b0..071e3826b20a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -135,7 +135,7 @@ case class CreateViewCommand( referredTempFunctions) catalog.createTempView(name.table, tableDefinition, overrideIfExists = replace) } else if (viewType == GlobalTempView) { - val db = sparkSession.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE) + val db = sparkSession.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE) val viewIdent = TableIdentifier(name.table, Option(db)) val aliasedPlan = aliasPlan(sparkSession, analyzedPlan) val tableDefinition = createTemporaryViewRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala index 4fa1e0c1f2c58..fd47feef25d57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.util.SchemaUtils object BucketingUtils { // The file name of bucketed data should have 3 parts: @@ -53,10 +54,7 @@ object BucketingUtils { bucketIdGenerator(mutableInternalRow).getInt(0) } - def canBucketOn(dataType: DataType): Boolean = dataType match { - case st: StringType => st.supportsBinaryOrdering - case other => true - } + def canBucketOn(dataType: DataType): Boolean = !SchemaUtils.hasNonUTF8BinaryCollation(dataType) def bucketIdToString(id: Int): String = f"_$id%05d" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index d88b5ee8877d7..968c204841e46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -267,7 +267,8 @@ case class DataSource( checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false) createInMemoryFileIndex(globbedPaths) }) - val forceNullable = sparkSession.conf.get(SQLConf.FILE_SOURCE_SCHEMA_FORCE_NULLABLE) + val forceNullable = sparkSession.sessionState.conf + .getConf(SQLConf.FILE_SOURCE_SCHEMA_FORCE_NULLABLE) val sourceDataSchema = if (forceNullable) dataSchema.asNullable else dataSchema SourceInfo( s"FileSource[$path]", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 1dd2659a1b169..2be4b236872f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoDir, I import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.catalyst.util.{GeneratedColumn, ResolveDefaultColumns, V2ExpressionBuilder} +import org.apache.spark.sql.catalyst.util.{GeneratedColumn, IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder} import org.apache.spark.sql.connector.catalog.{SupportsRead, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} @@ -146,6 +146,11 @@ object DataSourceAnalysis extends Rule[LogicalPlan] { tableDesc.identifier, "generated columns") } + if (IdentityColumn.hasIdentityColumns(newSchema)) { + throw QueryCompilationErrors.unsupportedTableOperationError( + tableDesc.identifier, "identity columns") + } + val newTableDesc = tableDesc.copy(schema = newSchema) CreateDataSourceTableCommand(newTableDesc, ignoreIfExists = mode == SaveMode.Ignore) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 676a2ab64d0a3..ffdca65151052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -29,7 +29,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.Path -import org.apache.spark.SparkRuntimeException +import org.apache.spark.{SparkException, SparkRuntimeException} import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -550,7 +550,7 @@ object PartitioningUtils extends SQLConfHelper { Cast(Literal(unescapePathName(value)), it).eval() case BinaryType => value.getBytes() case BooleanType => value.toBoolean - case dt => throw QueryExecutionErrors.typeUnsupportedError(dt) + case dt => throw SparkException.internalError(s"Unsupported partition type: $dt") } def validatePartitionColumn( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 5423232db4293..e44f1d35e9cdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -19,12 +19,14 @@ package org.apache.spark.sql.execution.datasources import scala.util.control.NonFatal +import org.apache.spark.SparkThrowable import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, CTERelationDef, 
LogicalPlan, WithCTE} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.LeafRunnableCommand -import org.apache.spark.sql.sources.CreatableRelationProvider +import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider} /** * Saves the results of `query` in to a data source. @@ -44,8 +46,26 @@ case class SaveIntoDataSourceCommand( override def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { - val relation = dataSource.createRelation( - sparkSession.sqlContext, mode, options, Dataset.ofRows(sparkSession, query)) + var relation: BaseRelation = null + + try { + relation = dataSource.createRelation( + sparkSession.sqlContext, mode, options, Dataset.ofRows(sparkSession, query)) + } catch { + case e: SparkThrowable => + // We should avoid wrapping `SparkThrowable` exceptions into another `AnalysisException`. + throw e + case e @ (_: NullPointerException | _: MatchError | _: ArrayIndexOutOfBoundsException) => + // These are some of the exceptions thrown by the data source API. We catch these + // exceptions here and rethrow QueryCompilationErrors.externalDataSourceException to + // provide a more friendly error message for the user. This list is not exhaustive. + throw QueryCompilationErrors.externalDataSourceException(e) + case e: Throwable => + // For other exceptions, just rethrow it, since we don't have enough information to + // provide a better error message for the user at the moment. We may want to further + // improve the error message handling in the future. + throw e + } try { val logicalRelation = LogicalRelation(relation, toAttributes(relation.schema), None, false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala index cbff526592f92..54c100282e2db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala @@ -98,7 +98,7 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val filterFuncs = filters.flatMap(filter => createFilterFunction(filter)) - val maxLength = sparkSession.conf.get(SOURCES_BINARY_FILE_MAX_LENGTH) + val maxLength = sparkSession.sessionState.conf.getConf(SOURCES_BINARY_FILE_MAX_LENGTH) file: PartitionedFile => { val path = file.toPath diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index fc6cba786c4ed..d9367d92d462e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -115,7 +115,7 @@ case class CreateTempViewUsing( }.logicalPlan if (global) { - val db = sparkSession.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE) + val db = sparkSession.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE) val viewIdent = TableIdentifier(tableIdent.table, Option(db)) val viewDefinition = createTemporaryViewRelation( viewIdent, diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 112ee2c5450b2..76cd33b815edd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, ResolveDefaultColumns, V2ExpressionBuilder} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder} import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDeleteV2, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable} import org.apache.spark.sql.connector.catalog.index.SupportsIndex import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} @@ -185,6 +185,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat val statementType = "CREATE TABLE" GeneratedColumn.validateGeneratedColumns( c.tableSchema, catalog.asTableCatalog, ident, statementType) + IdentityColumn.validateIdentityColumn(c.tableSchema, catalog.asTableCatalog, ident) CreateTableExec( catalog.asTableCatalog, @@ -214,6 +215,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat val statementType = "REPLACE TABLE" GeneratedColumn.validateGeneratedColumns( c.tableSchema, catalog.asTableCatalog, ident, statementType) + IdentityColumn.validateIdentityColumn(c.tableSchema, catalog.asTableCatalog, ident) val v2Columns = columns.map(_.toV2Column(statementType)).toArray catalog match { @@ -551,6 +553,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat systemScope, pattern) :: Nil + case c: Call => + ExplainOnlySparkPlan(c) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala new file mode 100644 index 0000000000000..bbf56eaa71184 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.LeafLike +import org.apache.spark.sql.execution.SparkPlan + +case class ExplainOnlySparkPlan(toExplain: LogicalPlan) extends SparkPlan with LeafLike[SparkPlan] { + + override def output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + toExplain.simpleString(maxFields) + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index f56f9436d9437..4eee731e0b2d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -146,6 +146,19 @@ abstract class FileTable( val entry = options.get(DataSource.GLOB_PATHS_KEY) Option(entry).map(_ == "true").getOrElse(true) } + + /** + * Merge the options of FileTable and the table operation while respecting the + * keys of the table operation. + * + * @param options The options of the table operation. + * @return + */ + protected def mergedOptions(options: CaseInsensitiveStringMap): CaseInsensitiveStringMap = { + val finalOptions = this.options.asCaseSensitiveMap().asScala ++ + options.asCaseSensitiveMap().asScala + new CaseInsensitiveStringMap(finalOptions.asJava) + } } object FileTable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala index 8b4fd3af6ded7..4c201ca66cf6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala @@ -38,7 +38,7 @@ case class CSVTable( fallbackFileFormat: Class[_ <: FileFormat]) extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { override def newScanBuilder(options: CaseInsensitiveStringMap): CSVScanBuilder = - CSVScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + CSVScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options)) override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { val parsedOptions = new CSVOptions( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala index c567e87e7d767..54244c4d95e77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala @@ -38,7 +38,7 @@ case class JsonTable( fallbackFileFormat: Class[_ <: FileFormat]) extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { override def newScanBuilder(options: CaseInsensitiveStringMap): JsonScanBuilder = - new JsonScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + JsonScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options)) override def 
inferSchema(files: Seq[FileStatus]): Option[StructType] = { val parsedOptions = new JSONOptionsInRead( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala index ca4b83b3c58f1..1037370967c87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala @@ -38,7 +38,7 @@ case class OrcTable( extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { override def newScanBuilder(options: CaseInsensitiveStringMap): OrcScanBuilder = - new OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options)) override def inferSchema(files: Seq[FileStatus]): Option[StructType] = OrcUtils.inferSchema(sparkSession, files, options.asScala.toMap) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala index e593ad7d0c0cd..8463a05569c05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala @@ -38,7 +38,7 @@ case class ParquetTable( extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { override def newScanBuilder(options: CaseInsensitiveStringMap): ParquetScanBuilder = - new ParquetScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + ParquetScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options)) override def inferSchema(files: Seq[FileStatus]): Option[StructType] = ParquetUtils.inferSchema(sparkSession, options.asScala.toMap, files) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 83399e2cac01b..50b90641d309b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging -import org.apache.spark.sql.{RuntimeConfig, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.DataSourceOptions import org.apache.spark.sql.connector.catalog.{Table, TableProvider} import org.apache.spark.sql.connector.expressions.Transform @@ -119,9 +119,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging throw StateDataSourceErrors.offsetMetadataLogUnavailable(batchId, checkpointLocation) ) - val clonedRuntimeConf = new RuntimeConfig(session.sessionState.conf.clone()) - OffsetSeqMetadata.setSessionConf(metadata, clonedRuntimeConf) - StateStoreConf(clonedRuntimeConf.sqlConf) + val clonedSqlConf = session.sessionState.conf.clone() + OffsetSeqMetadata.setSessionConf(metadata, clonedSqlConf) + StateStoreConf(clonedSqlConf) case _ => throw StateDataSourceErrors.offsetLogUnavailable(batchId, checkpointLocation) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 53576c335cb01..24166a46bbd39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.datasources.v2.state import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{NullType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{NextIterator, SerializableConfiguration} @@ -68,10 +69,23 @@ abstract class StatePartitionReaderBase( stateVariableInfoOpt: Option[TransformWithStateVariableInfo], stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]) extends PartitionReader[InternalRow] with Logging { - protected val keySchema = SchemaUtil.getSchemaAsDataType( - schema, "key").asInstanceOf[StructType] - protected val valueSchema = SchemaUtil.getSchemaAsDataType( - schema, "value").asInstanceOf[StructType] + // Used primarily as a placeholder for the value schema in the context of + // state variables used within the transformWithState operator. 
+ private val schemaForValueRow: StructType = + StructType(Array(StructField("__dummy__", NullType))) + + protected val keySchema = { + if (!SchemaUtil.isMapStateVariable(stateVariableInfoOpt)) { + SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] + } else SchemaUtil.getCompositeKeySchema(schema) + } + + protected val valueSchema = if (stateVariableInfoOpt.isDefined) { + schemaForValueRow + } else { + SchemaUtil.getSchemaAsDataType( + schema, "value").asInstanceOf[StructType] + } protected lazy val provider: StateStoreProvider = { val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString, @@ -84,10 +98,17 @@ abstract class StatePartitionReaderBase( false } + val useMultipleValuesPerKey = if (stateVariableInfoOpt.isDefined && + stateVariableInfoOpt.get.stateVariableType == StateVariableType.ListState) { + true + } else { + false + } + val provider = StateStoreProvider.createAndInit( stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec, useColumnFamilies = useColFamilies, storeConf, hadoopConf.value, - useMultipleValuesPerKey = false) + useMultipleValuesPerKey = useMultipleValuesPerKey) if (useColFamilies) { val store = provider.getStore(partition.sourceOptions.batchId + 1) @@ -99,7 +120,7 @@ abstract class StatePartitionReaderBase( stateStoreColFamilySchema.keySchema, stateStoreColFamilySchema.valueSchema, stateStoreColFamilySchema.keyStateEncoderSpec.get, - useMultipleValuesPerKey = false) + useMultipleValuesPerKey = useMultipleValuesPerKey) } provider } @@ -160,32 +181,43 @@ class StatePartitionReader( override lazy val iter: Iterator[InternalRow] = { val stateVarName = stateVariableInfoOpt .map(_.stateName).getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME) - store - .iterator(stateVarName) - .map { pair => - stateVariableInfoOpt match { - case Some(stateVarInfo) => - val stateVarType = stateVarInfo.stateVariableType - val hasTTLEnabled = stateVarInfo.ttlEnabled - - stateVarType match { - case StateVariableType.ValueState => - if (hasTTLEnabled) { - SchemaUtil.unifyStateRowPairWithTTL((pair.key, pair.value), valueSchema, - partition.partition) - } else { + if (SchemaUtil.isMapStateVariable(stateVariableInfoOpt)) { + SchemaUtil.unifyMapStateRowPair( + store.iterator(stateVarName), keySchema, partition.partition) + } else { + store + .iterator(stateVarName) + .map { pair => + stateVariableInfoOpt match { + case Some(stateVarInfo) => + val stateVarType = stateVarInfo.stateVariableType + + stateVarType match { + case StateVariableType.ValueState => SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition) - } - case _ => - throw new IllegalStateException( - s"Unsupported state variable type: $stateVarType") - } + case StateVariableType.ListState => + val key = pair.key + val result = store.valuesIterator(key, stateVarName) + var unsafeRowArr: Seq[UnsafeRow] = Seq.empty + result.foreach { entry => + unsafeRowArr = unsafeRowArr :+ entry.copy() + } + // convert the list of values to array type + val arrData = new GenericArrayData(unsafeRowArr.toArray) + SchemaUtil.unifyStateRowPairWithMultipleValues((pair.key, arrData), + partition.partition) + + case _ => + throw new IllegalStateException( + s"Unsupported state variable type: $stateVarType") + } - case None => - SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition) + case None => + SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition) + } } - } + } } override def close(): Unit = { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala index 9dd357530ec40..88ea06d598e56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala @@ -16,13 +16,18 @@ */ package org.apache.spark.sql.execution.datasources.v2.state.utils +import scala.collection.mutable +import scala.util.control.NonFatal + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, StateSourceOptions} -import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo} -import org.apache.spark.sql.execution.streaming.state.StateStoreColFamilySchema -import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType} +import org.apache.spark.sql.execution.streaming.StateVariableType._ +import org.apache.spark.sql.execution.streaming.TransformWithStateVariableInfo +import org.apache.spark.sql.execution.streaming.state.{StateStoreColFamilySchema, UnsafeRowPair} +import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, MapType, StringType, StructType} import org.apache.spark.util.ArrayImplicits._ object SchemaUtil { @@ -70,18 +75,122 @@ row } - def unifyStateRowPairWithTTL( - pair: (UnsafeRow, UnsafeRow), - valueSchema: StructType, + def unifyStateRowPairWithMultipleValues( + pair: (UnsafeRow, GenericArrayData), partition: Int): InternalRow = { - val row = new GenericInternalRow(4) + val row = new GenericInternalRow(3) row.update(0, pair._1) - row.update(1, pair._2.get(0, valueSchema)) - row.update(2, pair._2.get(1, LongType)) - row.update(3, partition) + row.update(1, pair._2) + row.update(2, partition) row } + /** + * For map state variables, state rows are stored with a composite key. + * To return grouping key -> Map{user key -> value} as one state reader row to + * the users, we need to perform grouping on state rows by their grouping key, + * and construct a map for that grouping key. + * + * We traverse the iterator returned from the state store, + * and will return a row from `next()` only if the grouping key in the next row + * from the state store is different (or there are no more rows). + * + * Note that all state rows with the same grouping key are co-located, so they will + * appear consecutively during the iterator traversal.
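The comment above describes a classic "collapse consecutive rows with the same key" iterator. A simplified sketch of that idea on plain Scala tuples, not the reader's actual code (which works on `UnsafeRowPair`s and must copy rows before buffering them):

```scala
// Collapse consecutive (groupingKey, userKey, value) entries that share a grouping key
// into one (groupingKey, Map(userKey -> value)) row. Relies on equal keys being adjacent.
def groupConsecutive[K, UK, V](rows: Iterator[(K, UK, V)]): Iterator[(K, Map[UK, V])] =
  new Iterator[(K, Map[UK, V])] {
    private val buffered = rows.buffered
    override def hasNext: Boolean = buffered.hasNext
    override def next(): (K, Map[UK, V]) = {
      val (key, firstUserKey, firstValue) = buffered.next()
      var acc = Map(firstUserKey -> firstValue)
      while (buffered.hasNext && buffered.head._1 == key) {
        val (_, userKey, value) = buffered.next()
        acc += (userKey -> value)
      }
      (key, acc)
    }
  }

// groupConsecutive(Iterator(("a", 1, "x"), ("a", 2, "y"), ("b", 1, "z"))).toList
// => List(("a", Map(1 -> "x", 2 -> "y")), ("b", Map(1 -> "z")))
```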
+ */ + def unifyMapStateRowPair( + stateRows: Iterator[UnsafeRowPair], + compositeKeySchema: StructType, + partitionId: Int): Iterator[InternalRow] = { + val groupingKeySchema = SchemaUtil.getSchemaAsDataType( + compositeKeySchema, "key" + ).asInstanceOf[StructType] + val userKeySchema = SchemaUtil.getSchemaAsDataType( + compositeKeySchema, "userKey" + ).asInstanceOf[StructType] + + def appendKVPairToMap( + curMap: mutable.Map[Any, Any], + stateRowPair: UnsafeRowPair): Unit = { + curMap += ( + stateRowPair.key.get(1, userKeySchema) + .asInstanceOf[UnsafeRow].copy() -> + stateRowPair.value.copy() + ) + } + + def createDataRow( + groupingKey: Any, + curMap: mutable.Map[Any, Any]): GenericInternalRow = { + val row = new GenericInternalRow(3) + val mapData = ArrayBasedMapData(curMap) + row.update(0, groupingKey) + row.update(1, mapData) + row.update(2, partitionId) + row + } + + // All of the rows with the same grouping key were co-located and were + // grouped together consecutively. + new Iterator[InternalRow] { + var curGroupingKey: UnsafeRow = _ + var curStateRowPair: UnsafeRowPair = _ + val curMap = mutable.Map.empty[Any, Any] + + override def hasNext: Boolean = + stateRows.hasNext || !curMap.isEmpty + + override def next(): InternalRow = { + var foundNewGroupingKey = false + while (stateRows.hasNext && !foundNewGroupingKey) { + curStateRowPair = stateRows.next() + if (curGroupingKey == null) { + // First time in the iterator + // Need to make a copy because we need to keep the + // value across function calls + curGroupingKey = curStateRowPair.key + .get(0, groupingKeySchema).asInstanceOf[UnsafeRow].copy() + appendKVPairToMap(curMap, curStateRowPair) + } else { + val curPairGroupingKey = + curStateRowPair.key.get(0, groupingKeySchema) + if (curPairGroupingKey == curGroupingKey) { + appendKVPairToMap(curMap, curStateRowPair) + } else { + // find a different grouping key, exit loop and return a row + foundNewGroupingKey = true + } + } + } + if (foundNewGroupingKey) { + // found a different grouping key + val row = createDataRow(curGroupingKey, curMap) + // update vars + curGroupingKey = + curStateRowPair.key.get(0, groupingKeySchema) + .asInstanceOf[UnsafeRow].copy() + // empty the map, append current row + curMap.clear() + appendKVPairToMap(curMap, curStateRowPair) + // return map value of previous grouping key + row + } else { + if (curMap.isEmpty) { + throw new NoSuchElementException("Please check if the iterator hasNext(); Likely " + + "user is trying to get element from an exhausted iterator.") + } + else { + // reach the end of the state rows + val row = createDataRow(curGroupingKey, curMap) + // clear the map to end the iterator + curMap.clear() + row + } + } + } + } + } + def isValidSchema( sourceOptions: StateSourceOptions, schema: StructType, @@ -91,23 +200,26 @@ object SchemaUtil { "change_type" -> classOf[StringType], "key" -> classOf[StructType], "value" -> classOf[StructType], - "partition_id" -> classOf[IntegerType], - "expiration_timestamp" -> classOf[LongType]) + "single_value" -> classOf[StructType], + "list_value" -> classOf[ArrayType], + "map_value" -> classOf[MapType], + "partition_id" -> classOf[IntegerType]) val expectedFieldNames = if (sourceOptions.readChangeFeed) { Seq("batch_id", "change_type", "key", "value", "partition_id") } else if (transformWithStateVariableInfoOpt.isDefined) { val stateVarInfo = transformWithStateVariableInfoOpt.get - val hasTTLEnabled = stateVarInfo.ttlEnabled val stateVarType = stateVarInfo.stateVariableType stateVarType match { - case 
StateVariableType.ValueState => - if (hasTTLEnabled) { - Seq("key", "value", "expiration_timestamp", "partition_id") - } else { - Seq("key", "value", "partition_id") - } + case ValueState => + Seq("key", "single_value", "partition_id") + + case ListState => + Seq("key", "list_value", "partition_id") + + case MapState => + Seq("key", "map_value", "partition_id") case _ => throw StateDataSourceErrors @@ -131,27 +243,73 @@ stateVarInfo: TransformWithStateVariableInfo, stateStoreColFamilySchema: StateStoreColFamilySchema): StructType = { val stateVarType = stateVarInfo.stateVariableType - val hasTTLEnabled = stateVarInfo.ttlEnabled stateVarType match { - case StateVariableType.ValueState => - if (hasTTLEnabled) { - val ttlValueSchema = SchemaUtil.getSchemaAsDataType( - stateStoreColFamilySchema.valueSchema, "value").asInstanceOf[StructType] - new StructType() - .add("key", stateStoreColFamilySchema.keySchema) - .add("value", ttlValueSchema) - .add("expiration_timestamp", LongType) - .add("partition_id", IntegerType) - } else { - new StructType() - .add("key", stateStoreColFamilySchema.keySchema) - .add("value", stateStoreColFamilySchema.valueSchema) - .add("partition_id", IntegerType) - } + case ValueState => + new StructType() + .add("key", stateStoreColFamilySchema.keySchema) + .add("single_value", stateStoreColFamilySchema.valueSchema) + .add("partition_id", IntegerType) + + case ListState => + new StructType() + .add("key", stateStoreColFamilySchema.keySchema) + .add("list_value", ArrayType(stateStoreColFamilySchema.valueSchema)) + .add("partition_id", IntegerType) + + case MapState => + val groupingKeySchema = SchemaUtil.getSchemaAsDataType( + stateStoreColFamilySchema.keySchema, "key") + val userKeySchema = stateStoreColFamilySchema.userKeyEncoderSchema.get + val valueMapSchema = MapType.apply( + keyType = userKeySchema, + valueType = stateStoreColFamilySchema.valueSchema + ) + + new StructType() + .add("key", groupingKeySchema) + .add("map_value", valueMapSchema) + .add("partition_id", IntegerType) case _ => throw StateDataSourceErrors.internalError(s"Unsupported state variable type $stateVarType") } } + + /** + * Helper functions for the map state data source reader. + * + * Map state variables are stored in the RocksDB state store with the schema of + * `TransformWithStateKeyValueRowSchemaUtils.getCompositeKeySchema()`; + * but the state store reader needs to return rows in the format: + * "key": groupingKey, "map_value": Map(userKey -> value). + * + * The following functions help to translate between the two schemas. + */ + def isMapStateVariable( + stateVariableInfoOpt: Option[TransformWithStateVariableInfo]): Boolean = { + stateVariableInfoOpt.isDefined && + stateVariableInfoOpt.get.stateVariableType == MapState + } + + /** + * Given the key-value schema generated from `generateSchemaForStateVar()`, + * returns the composite key schema under which keys are stored in the state store. + */ + def getCompositeKeySchema(schema: StructType): StructType = { + val groupingKeySchema = SchemaUtil.getSchemaAsDataType( + schema, "key").asInstanceOf[StructType] + val userKeySchema = try { + Option( + SchemaUtil.getSchemaAsDataType(schema, "map_value").asInstanceOf[MapType] + .keyType.asInstanceOf[StructType]) + } catch { + case NonFatal(e) => + throw StateDataSourceErrors.internalError(s"No such field named as 'map_value' " + + s"during state source reader schema initialization. 
Internal exception message: $e") + } + new StructType() + .add("key", groupingKeySchema) + .add("userKey", userKeySchema.get) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala index 046bdcb69846e..87ae34532f88a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala @@ -34,7 +34,7 @@ case class TextTable( fallbackFileFormat: Class[_ <: FileFormat]) extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { override def newScanBuilder(options: CaseInsensitiveStringMap): TextScanBuilder = - TextScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + TextScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options)) override def inferSchema(files: Seq[FileStatus]): Option[StructType] = Some(StructType(Array(StructField("value", StringType)))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInArrowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInArrowExec.scala index e91140414732b..a2d200dc86e18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInArrowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInArrowExec.scala @@ -23,13 +23,13 @@ import org.apache.spark.sql.execution.SparkPlan /** - * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapCoGroupsInPandas]] + * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapCoGroupsInArrow]] * * The input dataframes are first Cogrouped. Rows from each side of the cogroup are passed to the * Python worker via Arrow. As each side of the cogroup may have a different schema we send every * group in its own Arrow stream. - * The Python worker turns the resulting record batches to `pandas.DataFrame`s, invokes the - * user-defined function, and passes the resulting `pandas.DataFrame` + * The Python worker turns the resulting record batches to `pyarrow.Table`s, invokes the + * user-defined function, and passes the resulting `pyarrow.Table` * as an Arrow record batch. Finally, each record batch is turned to * Iterator[InternalRow] using ColumnarBatch. * @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.SparkPlan * Both the Python worker and the Java executor need to have enough memory to * hold the largest cogroup. The memory on the Java side is used to construct the * record batches (off heap memory). The memory on the Python side is used for - * holding the `pandas.DataFrame`. It's possible to further split one group into + * holding the `pyarrow.Table`. It's possible to further split one group into * multiple record batches to reduce the memory footprint on the Java side, this * is left as future work. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala index 942aaf6e44c17..6569b29f3954f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala @@ -25,11 +25,11 @@ import org.apache.spark.sql.types.{StructField, StructType} /** - * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]] + * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInArrow]] * * Rows in each group are passed to the Python worker as an Arrow record batch. - * The Python worker turns the record batch to a `pandas.DataFrame`, invoke the - * user-defined function, and passes the resulting `pandas.DataFrame` + * The Python worker turns the record batch to a `pyarrow.Table`, invokes the + * user-defined function, and passes the resulting `pyarrow.Table` * as an Arrow record batch. Finally, each record batch is turned to * Iterator[InternalRow] using ColumnarBatch. * @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.{StructField, StructType} * Both the Python worker and the Java executor need to have enough memory to * hold the largest group. The memory on the Java side is used to construct the * record batch (off heap memory). The memory on the Python side is used for - * holding the `pandas.DataFrame`. It's possible to further split one group into + * holding the `pyarrow.Table`. It's possible to further split one group into * multiple record batches to reduce the memory footprint on the Java side, this * is left as future work. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala index 67b264436fea9..ed7ff6a753487 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -21,12 +21,15 @@ import java.io.File import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock -import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext} +import scala.util.control.NonFatal + +import org.apache.spark.{JobArtifactSet, SparkEnv, SparkThrowable, TaskContext} import org.apache.spark.api.python._ import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.sources.ForeachUserFuncException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.util.{NextIterator, Utils} @@ -53,6 +56,8 @@ class WriterThread(outputIterator: Iterator[Array[Byte]]) } catch { // Cache exceptions seen while evaluating the Python function on the streamed input. The // parent thread will throw this crashed exception eventually. 
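The catch clause added below follows a pattern that recurs throughout this patch: wrap failures coming from user code in a dedicated exception type, but let Spark's own classified errors (`SparkThrowable`) pass through untouched. A hedged sketch of that pattern; `UserFunctionFailure` is a hypothetical stand-in for `ForeachUserFuncException`, and `org.apache.spark.SparkThrowable` is assumed to be on the classpath via spark-core:

```scala
import scala.util.control.NonFatal

import org.apache.spark.SparkThrowable

// Hypothetical wrapper type standing in for ForeachUserFuncException.
final case class UserFunctionFailure(cause: Throwable) extends RuntimeException(cause)

def runUserCode[T](body: => T): T =
  try body
  catch {
    // Spark's classified errors keep their error class; anything else non-fatal is
    // wrapped so callers can tell "user code failed" apart from engine failures.
    case NonFatal(e) if !e.isInstanceOf[SparkThrowable] => throw UserFunctionFailure(e)
  }
```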
+ case NonFatal(e) if !e.isInstanceOf[SparkThrowable] => + _exception = ForeachUserFuncException(e) case t: Throwable => _exception = t } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala index 33612b6947f27..11ab706e8abb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala @@ -233,7 +233,17 @@ class PythonStreamingSourceRunner( s"stream reader for $pythonExec", 0, Long.MaxValue) def readArrowRecordBatches(): Iterator[InternalRow] = { - assert(dataIn.readInt() == SpecialLengths.START_ARROW_STREAM) + val status = dataIn.readInt() + status match { + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + val msg = PythonWorkerUtils.readUTF(dataIn) + throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError( + action = "prefetchArrowBatches", msg) + case SpecialLengths.START_ARROW_STREAM => + case _ => + throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError( + action = "prefetchArrowBatches", s"unknown status code $status") + } val reader = new ArrowStreamReader(dataIn, allocator) val root = reader.getVectorSchemaRoot() // When input is empty schema can't be read. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala index 35275eb16ebab..f6a3f1b1394f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala @@ -133,7 +133,7 @@ case class TransformWithStateInPandasExec( val data = groupAndProject(dataIterator, groupingAttributes, child.output, dedupAttributes) val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId, - groupingKeyExprEncoder, timeMode) + groupingKeyExprEncoder, timeMode, isStreaming = true, batchTimestampMs, metrics) val runner = new TransformWithStateInPandasPythonRunner( chainedFunc, PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index d23f1dbd422a7..b5ec26b401d28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException} import java.net.ServerSocket +import java.time.Duration import scala.collection.mutable @@ -30,7 +31,7 @@ import org.apache.spark.sql.api.python.PythonSQLUtils import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, 
StateVariableRequest, ValueStateCall} -import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.streaming.{TTLConfig, ValueState} import org.apache.spark.sql.types.StructType /** @@ -153,7 +154,10 @@ class TransformWithStateInPandasStateServer( case StatefulProcessorCall.MethodCase.GETVALUESTATE => val stateName = message.getGetValueState.getStateName val schema = message.getGetValueState.getSchema - initializeValueState(stateName, schema) + val ttlDurationMs = if (message.getGetValueState.hasTtl) { + Some(message.getGetValueState.getTtl.getDurationMs) + } else None + initializeValueState(stateName, schema, ttlDurationMs) case _ => throw new IllegalArgumentException("Invalid method call") } @@ -228,10 +232,18 @@ class TransformWithStateInPandasStateServer( outputStream.write(responseMessageBytes) } - private def initializeValueState(stateName: String, schemaString: String): Unit = { + private def initializeValueState( + stateName: String, + schemaString: String, + ttlDurationMs: Option[Int]): Unit = { if (!valueStates.contains(stateName)) { val schema = StructType.fromString(schemaString) - val state = statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema)) + val state = if (ttlDurationMs.isEmpty) { + statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema)) + } else { + statefulProcessorHandle.getValueState( + stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get))) + } val valueRowDeserializer = ExpressionEncoder(schema).resolveAndBind().createDeserializer() valueStates.put(stateName, (state, schema, valueRowDeserializer)) sendResponse(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala index aa393211a1c15..cb7e71bda84dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.ThreadUtils @@ -40,9 +41,16 @@ trait AsyncLogPurge extends Logging { private val purgeRunning = new AtomicBoolean(false) + private val statefulMetadataPurgeRunning = new AtomicBoolean(false) + protected def purge(threshold: Long): Unit - protected lazy val useAsyncPurge: Boolean = sparkSession.conf.get(SQLConf.ASYNC_LOG_PURGE) + // This method is used to purge the oldest OperatorStateMetadata and StateSchema files + // which are written per run. 
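The new `purgeStatefulMetadataAsync` below uses the same "at most one purge in flight" guard as the existing `purgeAsync`: an `AtomicBoolean` flipped with `compareAndSet` acts as a non-blocking lock and is reset in `finally`. A self-contained sketch of just that guard, with an illustrative executor and message:

```scala
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicBoolean

object AsyncPurgeSketch {
  private val running = new AtomicBoolean(false)
  private val executor = Executors.newSingleThreadExecutor()

  def purgeAsync(doPurge: () => Unit): Unit = {
    if (running.compareAndSet(false, true)) {
      executor.execute(() => {
        try doPurge()
        finally running.set(false) // always release the guard, even on failure
      })
    } else {
      // A purge is already running; skip instead of queueing another one.
      println("Skipped purge; one is already in progress.")
    }
  }
}
```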
+ protected def purgeStatefulMetadata(plan: SparkPlan): Unit + + protected lazy val useAsyncPurge: Boolean = sparkSession.sessionState.conf + .getConf(SQLConf.ASYNC_LOG_PURGE) protected def purgeAsync(batchId: Long): Unit = { if (purgeRunning.compareAndSet(false, true)) { @@ -62,6 +70,24 @@ } } + protected def purgeStatefulMetadataAsync(plan: SparkPlan): Unit = { + if (statefulMetadataPurgeRunning.compareAndSet(false, true)) { + asyncPurgeExecutorService.execute(() => { + try { + purgeStatefulMetadata(plan) + } catch { + case throwable: Throwable => + logError("Encountered error while performing async log purge", throwable) + errorNotifier.markError(throwable) + } finally { + statefulMetadataPurgeRunning.set(false) + } + }) + } else { + log.debug("Skipped log purging since there is already one in progress.") + } + } + protected def asyncLogPurgeShutdown(): Unit = { ThreadUtils.shutdown(asyncPurgeExecutorService) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index d56dfebd61ba1..766caaab2285e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -18,8 +18,11 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.TimeUnit.NANOSECONDS +import scala.util.control.NonFatal + import org.apache.hadoop.conf.Configuration +import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -447,10 +450,33 @@ case class FlatMapGroupsWithStateExec( hasTimedOut, watermarkPresent) - // Call function, get the returned objects and convert them to rows - val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => - numOutputRows += 1 - getOutputRow(obj) + def withUserFuncExceptionHandling[T](func: => T): T = { + try { + func + } catch { + case NonFatal(e) if !e.isInstanceOf[SparkThrowable] => + throw FlatMapGroupsWithStateUserFuncException(e) + case f: Throwable => + throw f + } + } + + val mappedIterator = withUserFuncExceptionHandling { + func(keyObj, valueObjIter, groupState).map { obj => + numOutputRows += 1 + getOutputRow(obj) + } + } + + // Wrap user-provided fns with error handling + val wrappedMappedIterator = new Iterator[InternalRow] { + override def hasNext: Boolean = { + withUserFuncExceptionHandling(mappedIterator.hasNext) + } + + override def next(): InternalRow = { + withUserFuncExceptionHandling(mappedIterator.next()) + } } // When the iterator is consumed, then write changes to state @@ -472,7 +498,9 @@ case class FlatMapGroupsWithStateExec( } // Return an iterator of rows such that fully consumed, the updated state value will be saved - CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) + CompletionIterator[InternalRow, Iterator[InternalRow]]( + wrappedMappedIterator, onIteratorCompletion + ) } } } @@ -544,3 +572,13 @@ object FlatMapGroupsWithStateExec { } } } + + +/** + * Exception that wraps the exception thrown in the user-provided function of flatMapGroupsWithState.
+ */ +private[sql] case class FlatMapGroupsWithStateUserFuncException(cause: Throwable) + extends SparkException( + errorClass = "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR", + messageParameters = Map("reason" -> Option(cause.getMessage).getOrElse("")), + cause = cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index f59cdca8aefec..053aef6ced3a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -477,7 +477,7 @@ class MicroBatchExecution( // update offset metadata nextOffsets.metadata.foreach { metadata => - OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) + OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.sessionState.conf) execCtx.offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) @@ -842,6 +842,10 @@ class MicroBatchExecution( markMicroBatchExecutionStart(execCtx) + if (execCtx.previousContext.isEmpty) { + purgeStatefulMetadataAsync(execCtx.executionPlan.executedPlan) + } + val nextBatch = new Dataset(execCtx.executionPlan, ExpressionEncoder(execCtx.executionPlan.analyzed.schema)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index f0be33ad9a9d8..e1e5b3a7ef88e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -26,6 +26,7 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, SparkDataStream} import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager, SymmetricHashJoinStateManager} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ @@ -100,7 +101,9 @@ object OffsetSeqMetadata extends Logging { SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION, STREAMING_JOIN_STATE_FORMAT_VERSION, STATE_STORE_COMPRESSION_CODEC, - STATE_STORE_ROCKSDB_FORMAT_VERSION, STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION) + STATE_STORE_ROCKSDB_FORMAT_VERSION, STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION, + PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN + ) /** * Default values of relevant configurations that are used for backward compatibility. 
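
[Editor's note] Adding PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN to both the relevant-conf list and the default map keeps restarts deterministic: a value recorded in the offset log always wins, and checkpoints written before the conf existed fall back to the pinned default. A small generic sketch of that replay rule (the setter parameter stands in for the session conf setter):

// Re-apply streaming-relevant confs recorded in the offset log; older checkpoints that never
// recorded a key fall back to a pinned backward-compatibility default, if one exists.
def replayRecordedConfs(
    recorded: Map[String, String],         // confs stored in the offset log metadata
    compatDefaults: Map[String, String],   // defaults for keys predating the conf
    relevantKeys: Seq[String],
    setConf: (String, String) => Unit): Unit = {
  relevantKeys.foreach { key =>
    recorded.get(key) match {
      case Some(value) => setConf(key, value)                        // checkpointed value wins
      case None => compatDefaults.get(key).foreach(setConf(key, _))  // backward-compat default
    }
  }
}
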
@@ -121,7 +124,8 @@ object OffsetSeqMetadata extends Logging { STREAMING_JOIN_STATE_FORMAT_VERSION.key -> SymmetricHashJoinStateManager.legacyVersion.toString, STATE_STORE_COMPRESSION_CODEC.key -> CompressionCodec.LZ4, - STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "false" + STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "false", + PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true" ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) @@ -135,20 +139,21 @@ object OffsetSeqMetadata extends Logging { } /** Set the SparkSession configuration with the values in the metadata */ - def setSessionConf(metadata: OffsetSeqMetadata, sessionConf: RuntimeConfig): Unit = { + def setSessionConf(metadata: OffsetSeqMetadata, sessionConf: SQLConf): Unit = { + val configs = sessionConf.getAllConfs OffsetSeqMetadata.relevantSQLConfs.map(_.key).foreach { confKey => metadata.conf.get(confKey) match { case Some(valueInMetadata) => // Config value exists in the metadata, update the session config with this value - val optionalValueInSession = sessionConf.getOption(confKey) - if (optionalValueInSession.isDefined && optionalValueInSession.get != valueInMetadata) { + val optionalValueInSession = sessionConf.getConfString(confKey, null) + if (optionalValueInSession != null && optionalValueInSession != valueInMetadata) { logWarning(log"Updating the value of conf '${MDC(CONFIG, confKey)}' in current " + - log"session from '${MDC(OLD_VALUE, optionalValueInSession.get)}' " + + log"session from '${MDC(OLD_VALUE, optionalValueInSession)}' " + log"to '${MDC(NEW_VALUE, valueInMetadata)}'.") } - sessionConf.set(confKey, valueInMetadata) + sessionConf.setConfString(confKey, valueInMetadata) case None => // For backward compatibility, if a config was not recorded in the offset log, @@ -157,14 +162,17 @@ object OffsetSeqMetadata extends Logging { relevantSQLConfDefaultValues.get(confKey) match { case Some(defaultValue) => - sessionConf.set(confKey, defaultValue) + sessionConf.setConfString(confKey, defaultValue) logWarning(log"Conf '${MDC(CONFIG, confKey)}' was not found in the offset log, " + log"using default value '${MDC(DEFAULT_VALUE, defaultValue)}'") case None => - val valueStr = sessionConf.getOption(confKey).map { v => - s" Using existing session conf value '$v'." - }.getOrElse { " No value set in session conf." } + val value = sessionConf.getConfString(confKey, null) + val valueStr = if (value != null) { + s" Using existing session conf value '$value'." + } else { + " No value set in session conf." + } logWarning(log"Conf '${MDC(CONFIG, confKey)}' was not found in the offset log. 
" + log"${MDC(TIP, valueStr)}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 6fd58e13366e0..8f030884ad33b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -41,8 +41,10 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream} import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate, Write} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.StreamingExplainCommand -import org.apache.spark.sql.execution.streaming.sources.ForeachBatchUserFuncException +import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchUserFuncException, ForeachUserFuncException} +import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataV2FileManager import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend import org.apache.spark.sql.streaming._ @@ -346,9 +348,11 @@ abstract class StreamExecution( getLatestExecutionContext().updateStatusMessage("Stopped") case e: Throwable => val message = if (e.getMessage == null) "" else e.getMessage - val cause = if (e.isInstanceOf[ForeachBatchUserFuncException]) { + val cause = if (e.isInstanceOf[ForeachBatchUserFuncException] || + e.isInstanceOf[ForeachUserFuncException]) { // We want to maintain the current way users get the causing exception - // from the StreamingQueryException. Hence the ForeachBatch exception is unwrapped here. + // from the StreamingQueryException. + // Hence the ForeachBatch/Foreach exception is unwrapped here. 
e.getCause } else { e @@ -367,13 +371,7 @@ abstract class StreamExecution( messageParameters = Map( "id" -> id.toString, "runId" -> runId.toString, - "message" -> message, - "queryDebugString" -> toDebugString(includeLogicalPlan = isInitialized), - "startOffset" -> getLatestExecutionContext().startOffsets.toOffsetSeq( - sources.toSeq, getLatestExecutionContext().offsetSeqMetadata).toString, - "endOffset" -> getLatestExecutionContext().endOffsets.toOffsetSeq( - sources.toSeq, getLatestExecutionContext().offsetSeqMetadata).toString - )) + "message" -> message)) errorClassOpt = e match { case t: SparkThrowable => Option(t.getErrorClass) @@ -485,7 +483,7 @@ abstract class StreamExecution( @throws[TimeoutException] protected def interruptAndAwaitExecutionThreadTermination(): Unit = { val timeout = math.max( - sparkSession.conf.get(SQLConf.STREAMING_STOP_TIMEOUT), 0) + sparkSession.sessionState.conf.getConf(SQLConf.STREAMING_STOP_TIMEOUT), 0) queryExecutionThread.interrupt() queryExecutionThread.join(timeout) if (queryExecutionThread.isAlive) { @@ -690,6 +688,31 @@ abstract class StreamExecution( offsetLog.purge(threshold) commitLog.purge(threshold) } + + protected def purgeStatefulMetadata(plan: SparkPlan): Unit = { + plan.collect { case statefulOperator: StatefulOperator => + statefulOperator match { + case ssw: StateStoreWriter => + ssw.operatorStateMetadataVersion match { + case 2 => + // checkpointLocation of the operator is runId/state, and commitLog path is + // runId/commits, so we want the parent of the checkpointLocation to get the + // commit log path. + val parentCheckpointLocation = + new Path(statefulOperator.getStateInfo.checkpointLocation).getParent + + val fileManager = new OperatorStateMetadataV2FileManager( + parentCheckpointLocation, + sparkSession, + ssw + ) + fileManager.purgeMetadataFiles() + case _ => + } + case _ => + } + } + } } object StreamExecution { @@ -728,6 +751,7 @@ object StreamExecution { if e2.getCause != null => isInterruptionException(e2.getCause, sc) case fe: ForeachBatchUserFuncException => isInterruptionException(fe.getCause, sc) + case fes: ForeachUserFuncException => isInterruptionException(fes.getCause, sc) case se: SparkException => if (se.getCause == null) { isCancelledJobGroup(se.getMessage) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala index 6065af10ffe80..8811c59a50745 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala @@ -85,7 +85,7 @@ abstract class SingleKeyTTLStateImpl( import org.apache.spark.sql.execution.streaming.StateTTLSchema._ - private val ttlColumnFamilyName = s"_ttl_$stateName" + private val ttlColumnFamilyName = "$ttl_" + stateName private val keySchema = getSingleKeyTTLRowSchema(keyExprEnc.schema) private val keyTTLRowEncoder = new SingleKeyTTLEncoder(keyExprEnc) @@ -205,7 +205,7 @@ abstract class CompositeKeyTTLStateImpl[K]( import org.apache.spark.sql.execution.streaming.StateTTLSchema._ - private val ttlColumnFamilyName = s"_ttl_$stateName" + private val ttlColumnFamilyName = "$ttl_" + stateName private val keySchema = getCompositeKeyTTLRowSchema( keyExprEnc.schema, userKeyEncoder.schema ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala 
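
[Editor's note] Renaming the internal TTL (and, just below, timer) column families to a '$' prefix pairs with the provider-side check further along in this section, which now rejects user column families that start with '$' instead of '_'. A minimal sketch of that naming rule (illustrative, not the provider code itself):

// Internal column families are namespaced with '$'; the same prefix is rejected for
// user-created families so the two namespaces can never collide.
val InternalPrefix = '$'

def validateColumnFamilyName(name: String, isInternal: Boolean): Unit = {
  require(name.nonEmpty, "column family name must not be empty")
  if (!isInternal && name.charAt(0) == InternalPrefix) {
    throw new IllegalArgumentException(
      s"Column family '$name' starts with the reserved character '$InternalPrefix'")
  }
}

// validateColumnFamilyName("$ttl_countState", isInternal = true)   // accepted: internal state
// validateColumnFamilyName("$mine", isInternal = false)            // rejected: reserved prefix
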
index 650f57039a030..82a4226fcfd54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala @@ -30,8 +30,8 @@ import org.apache.spark.util.NextIterator * Singleton utils class used primarily while interacting with TimerState */ object TimerStateUtils { - val PROC_TIMERS_STATE_NAME = "_procTimers" - val EVENT_TIMERS_STATE_NAME = "_eventTimers" + val PROC_TIMERS_STATE_NAME = "$procTimers" + val EVENT_TIMERS_STATE_NAME = "$eventTimers" val KEY_TO_TIMESTAMP_CF = "_keyToTimestamp" val TIMESTAMP_TO_KEY_CF = "_timestampToKey" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala index 54c47ec4e6ed8..3e6f122f463d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -135,7 +135,8 @@ object WatermarkTracker { // saved in the checkpoint (e.g., old checkpoints), then the default `min` policy is enforced // through defaults specified in OffsetSeqMetadata.setSessionConf(). val policyName = conf.get( - SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY, MultipleWatermarkPolicy.DEFAULT_POLICY_NAME) + SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key, + MultipleWatermarkPolicy.DEFAULT_POLICY_NAME) new WatermarkTracker(MultipleWatermarkPolicy(policyName)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala index 420c3e3be16d6..273ffa6aefb7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala @@ -32,8 +32,9 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{HOST, PORT} import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, ContinuousPartitionReaderFactory, ContinuousStream, Offset, PartitionOffset} @@ -57,8 +58,7 @@ class TextSocketContinuousStream( implicit val defaultFormats: DefaultFormats = DefaultFormats - private val encoder = ExpressionEncoder.tuple(ExpressionEncoder[String](), - ExpressionEncoder[Timestamp]()) + private val encoder = encoderFor(Encoders.tuple(Encoders.STRING, Encoders.TIMESTAMP)) @GuardedBy("this") private var socket: Socket = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala index bbbe28ec7ab11..c0956a62e59fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.execution.streaming.sources import java.util +import scala.util.control.NonFatal + +import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.sql.{ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -146,6 +149,9 @@ class ForeachDataWriter[T]( try { writer.process(rowConverter(record)) } catch { + case NonFatal(e) if !e.isInstanceOf[SparkThrowable] => + errorOrNull = e + throw ForeachUserFuncException(e) case t: Throwable => errorOrNull = t throw t @@ -172,3 +178,12 @@ class ForeachDataWriter[T]( * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination. */ case object ForeachWriterCommitMessage extends WriterCommitMessage + +/** + * Exception that wraps the exception thrown in the user provided function in Foreach sink. + */ +private[sql] case class ForeachUserFuncException(cause: Throwable) + extends SparkException( + errorClass = "FOREACH_USER_FUNCTION_ERROR", + messageParameters = Map("reason" -> Option(cause.getMessage).getOrElse("")), + cause = cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index 3e68b3975e662..aa2f332afeff4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -23,14 +23,14 @@ import java.nio.charset.StandardCharsets import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path} +import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path, PathFilter} import org.json4s.{Formats, NoTypeHints} import org.json4s.jackson.Serialization import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors -import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, CommitLog, MetadataVersionUtil, OffsetSeqLog} +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, CommitLog, MetadataVersionUtil, OffsetSeqLog, StateStoreWriter} import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS} import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataUtils.{OperatorStateMetadataReader, OperatorStateMetadataWriter} @@ -358,3 +358,121 @@ class OperatorStateMetadataV2Reader( } } } + +/** + * A helper class to manage the metadata files for the operator state checkpoint. 
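
[Editor's note] The file manager introduced here keeps only the operator-metadata files that still have a matching commit-log entry, while always retaining at least one. A compact sketch of the selection rule implemented by purgeMetadataFiles/deleteMetadataFiles below, using batch ids in place of the numerically named files:

// Given the batch ids of existing metadata files and the oldest batch still present in the
// commit log, return the ids to delete: everything older than the newest file at or below
// that threshold. Newer files, and the newest retained file itself, are kept.
def metadataBatchesToDelete(metadataBatchIds: Seq[Long], oldestCommittedBatch: Long): Seq[Long] = {
  val atOrBelowThreshold = metadataBatchIds.filter(_ <= oldestCommittedBatch).sorted
  if (atOrBelowThreshold.isEmpty) Seq.empty
  else atOrBelowThreshold.dropRight(1)
}

// metadataBatchesToDelete(Seq(0L, 1L, 2L, 3L), oldestCommittedBatch = 2L)  ==>  Seq(0, 1)
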
+ * This class is used to manage the metadata files for OperatorStateMetadataV2, and + * provides utils to purge the oldest files such that we only keep the metadata files + * for which a commit log is present + * @param checkpointLocation The root path of the checkpoint directory + * @param sparkSession The sparkSession that is used to access the hadoopConf + * @param stateStoreWriter The operator that this fileManager is being created for + */ +class OperatorStateMetadataV2FileManager( + checkpointLocation: Path, + sparkSession: SparkSession, + stateStoreWriter: StateStoreWriter) extends Logging { + + private val hadoopConf = sparkSession.sessionState.newHadoopConf() + private val stateCheckpointPath = new Path(checkpointLocation, "state") + private val stateOpIdPath = new Path( + stateCheckpointPath, stateStoreWriter.getStateInfo.operatorId.toString) + private val commitLog = + new CommitLog(sparkSession, new Path(checkpointLocation, "commits").toString) + private val stateSchemaPath = stateStoreWriter.stateSchemaDirPath() + private val metadataDirPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) + private lazy val fm = CheckpointFileManager.create(metadataDirPath, hadoopConf) + + protected def isBatchFile(path: Path): Boolean = { + try { + path.getName.toLong + true + } catch { + case _: NumberFormatException => false + } + } + + /** + * A `PathFilter` to filter only batch files + */ + protected val batchFilesFilter: PathFilter = (path: Path) => isBatchFile(path) + + private def pathToBatchId(path: Path): Long = { + path.getName.toLong + } + + def purgeMetadataFiles(): Unit = { + val thresholdBatchId = findThresholdBatchId() + if (thresholdBatchId != 0) { + val earliestBatchIdKept = deleteMetadataFiles(thresholdBatchId) + // we need to delete everything from 0 to (earliestBatchIdKept - 1), inclusive + deleteSchemaFiles(earliestBatchIdKept - 1) + } + } + + // We only want to keep the metadata and schema files for which the commit + // log is present, so we will delete any file that precedes the batch for the oldest + // commit log + private def findThresholdBatchId(): Long = { + commitLog.listBatchesOnDisk.headOption.getOrElse(0L) + } + + private def deleteSchemaFiles(thresholdBatchId: Long): Unit = { + val schemaFiles = fm.list(stateSchemaPath).sorted.map(_.getPath) + val filesBeforeThreshold = schemaFiles.filter { path => + val batchIdInPath = path.getName.split("_").head.toLong + batchIdInPath <= thresholdBatchId + } + filesBeforeThreshold.foreach { path => + fm.delete(path) + } + } + + // Deletes all metadata files that are below a thresholdBatchId, except + // for the latest metadata file so that we have at least 1 metadata and schema + // file at all times per each stateful query + // Returns the batchId of the earliest schema file we want to keep + private def deleteMetadataFiles(thresholdBatchId: Long): Long = { + val metadataFiles = fm.list(metadataDirPath, batchFilesFilter) + + if (metadataFiles.isEmpty) { + return -1L // No files to delete + } + + // get all the metadata files for which we don't have commit logs + val sortedBatchIds = metadataFiles + .map(file => pathToBatchId(file.getPath)) + .filter(_ <= thresholdBatchId) + .sorted + + if (sortedBatchIds.isEmpty) { + return -1L + } + + // we don't want to delete the batchId right before the last one + val latestBatchId = sortedBatchIds.last + + metadataFiles.foreach { batchFile => + val batchId = pathToBatchId(batchFile.getPath) + if (batchId < latestBatchId) { + fm.delete(batchFile.getPath) + } + } + val 
latestMetadata = OperatorStateMetadataReader.createReader( + stateOpIdPath, + hadoopConf, + 2, + latestBatchId + ).read() + + // find the batchId of the earliest schema file we need to keep + val earliestBatchToKeep = latestMetadata match { + case Some(OperatorStateMetadataV2(_, stateStoreInfo, _)) => + val schemaFilePath = stateStoreInfo.head.stateSchemaFilePath + new Path(schemaFilePath).getName.split("_").head.toLong + case _ => 0 + } + + earliestBatchToKeep + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 7badc26bf0447..4a2aac43b3331 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -263,7 +263,7 @@ class RocksDB( logInfo(log"Loading ${MDC(LogKeys.VERSION_NUM, version)}") try { if (loadedVersion != version) { - closeDB() + closeDB(ignoreException = false) // deep copy is needed to avoid race condition // between maintenance and task threads fileManager.copyFileMapping() @@ -407,10 +407,12 @@ class RocksDB( * Replay change log from the loaded version to the target version. */ private def replayChangelog(endVersion: Long): Unit = { + logInfo(log"Replaying changelog from version " + + log"${MDC(LogKeys.LOADED_VERSION, loadedVersion)} -> " + + log"${MDC(LogKeys.END_VERSION, endVersion)}") for (v <- loadedVersion + 1 to endVersion) { - logInfo(log"replaying changelog from version " + - log"${MDC(LogKeys.LOADED_VERSION, loadedVersion)} -> " + - log"${MDC(LogKeys.END_VERSION, endVersion)}") + logInfo(log"Replaying changelog on version " + + log"${MDC(LogKeys.VERSION_NUM, v)}") var changelogReader: StateStoreChangelogReader = null try { changelogReader = fileManager.getChangelogReader(v, useColumnFamilies) @@ -644,15 +646,15 @@ class RocksDB( // is enabled. 
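
[Editor's note] The commit-path change that follows decouples the forced-snapshot upload from the changelog write: the force flag is reset once the upload succeeds, and the changelog commit now always runs, with the writer reference cleared even when the commit itself fails. A reduced sketch of that ordering (the trait is a hypothetical stand-in for the changelog writer):

import java.util.concurrent.atomic.AtomicBoolean

trait ChangelogWriter { def commit(): Unit }

def commitVersion(
    shouldForceSnapshot: AtomicBoolean,
    uploadSnapshot: () => Unit,
    writer: Option[ChangelogWriter],
    clearWriter: () => Unit): Unit = {
  if (shouldForceSnapshot.get()) {
    uploadSnapshot()
    shouldForceSnapshot.set(false)   // reset the flag only after the upload succeeds
  }
  try {
    assert(writer.isDefined)
    writer.foreach(_.commit())       // changelog files are always written
  } finally {
    clearWriter()                    // mirrors `changelogWriter = None` in the finally block
  }
}
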
if (shouldForceSnapshot.get()) { uploadSnapshot() + shouldForceSnapshot.set(false) + } + + // ensure that changelog files are always written + try { + assert(changelogWriter.isDefined) + changelogWriter.foreach(_.commit()) + } finally { changelogWriter = None - changelogWriter.foreach(_.abort()) - } else { - try { - assert(changelogWriter.isDefined) - changelogWriter.foreach(_.commit()) - } finally { - changelogWriter = None - } } } else { assert(changelogWriter.isEmpty) @@ -927,10 +929,12 @@ class RocksDB( * @param opType - operation type releasing the lock */ private def release(opType: RocksDBOpType): Unit = acquireLock.synchronized { - logInfo(log"RocksDB instance was released by ${MDC(LogKeys.THREAD, acquiredThreadInfo)} " + - log"for opType=${MDC(LogKeys.OP_TYPE, opType.toString)}") - acquiredThreadInfo = null - acquireLock.notifyAll() + if (acquiredThreadInfo != null) { + logInfo(log"RocksDB instance was released by ${MDC(LogKeys.THREAD, + acquiredThreadInfo)} " + log"for opType=${MDC(LogKeys.OP_TYPE, opType.toString)}") + acquiredThreadInfo = null + acquireLock.notifyAll() + } } private def getDBProperty(property: String): Long = db.getProperty(property).toLong @@ -941,13 +945,17 @@ class RocksDB( logInfo(log"Opened DB with conf ${MDC(LogKeys.CONFIG, conf)}") } - private def closeDB(): Unit = { + private def closeDB(ignoreException: Boolean = true): Unit = { if (db != null) { - // Cancel and wait until all background work finishes db.cancelAllBackgroundWork(true) - // Close the DB instance - db.close() + if (ignoreException) { + // Close the DB instance + db.close() + } else { + // Close the DB instance and throw the exception if any + db.closeE() + } db = null } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 685cc9a1533e4..85f80ce9eb1ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -567,7 +567,7 @@ private[sql] class RocksDBStateStoreProvider } // if the column family is not internal and uses reserved characters, throw an exception - if (!isInternal && colFamilyName.charAt(0) == '_') { + if (!isInternal && colFamilyName.charAt(0) == '$') { throw StateStoreErrors.cannotCreateColumnFamilyWithReservedChars(colFamilyName) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 47b1cb90e00a8..721d72b6a0991 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StatefulOperatorStateInfo} import org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaReader, SchemaWriter} +import org.apache.spark.sql.execution.streaming.state.StateSchemaCompatibilityChecker.SCHEMA_FORMAT_V3 import org.apache.spark.sql.internal.SessionState import 
org.apache.spark.sql.types.{DataType, StructType} @@ -95,7 +96,7 @@ class StateSchemaCompatibilityChecker( stateStoreColFamilySchema: List[StateStoreColFamilySchema], stateSchemaVersion: Int): Unit = { // Ensure that schema file path is passed explicitly for schema version 3 - if (stateSchemaVersion == 3 && newSchemaFilePath.isEmpty) { + if (stateSchemaVersion == SCHEMA_FORMAT_V3 && newSchemaFilePath.isEmpty) { throw new IllegalStateException("Schema file path is required for schema version 3") } @@ -167,12 +168,12 @@ class StateSchemaCompatibilityChecker( newStateSchema: List[StateStoreColFamilySchema], ignoreValueSchema: Boolean, stateSchemaVersion: Int): Boolean = { - val existingStateSchemaList = getExistingKeyAndValueSchema().sortBy(_.colFamilyName) - val newStateSchemaList = newStateSchema.sortBy(_.colFamilyName) + val existingStateSchemaList = getExistingKeyAndValueSchema() + val newStateSchemaList = newStateSchema if (existingStateSchemaList.isEmpty) { // write the schema file if it doesn't exist - createSchemaFile(newStateSchemaList, stateSchemaVersion) + createSchemaFile(newStateSchemaList.sortBy(_.colFamilyName), stateSchemaVersion) true } else { // validate if the new schema is compatible with the existing schema @@ -186,7 +187,13 @@ class StateSchemaCompatibilityChecker( check(existingStateSchema, newSchema, ignoreValueSchema) } } - false + val colFamiliesAddedOrRemoved = + (newStateSchemaList.map(_.colFamilyName).toSet != existingSchemaMap.keySet) + if (stateSchemaVersion == SCHEMA_FORMAT_V3 && colFamiliesAddedOrRemoved) { + createSchemaFile(newStateSchemaList.sortBy(_.colFamilyName), stateSchemaVersion) + } + // TODO: [SPARK-49535] Write Schema files after schema has changed for StateSchemaV3 + colFamiliesAddedOrRemoved } } @@ -195,6 +202,9 @@ class StateSchemaCompatibilityChecker( } object StateSchemaCompatibilityChecker { + + val SCHEMA_FORMAT_V3: Int = 3 + private def disallowBinaryInequalityColumn(schema: StructType): Unit = { if (!UnsafeRowUtils.isBinaryStable(schema)) { throw new SparkUnsupportedOperationException( @@ -274,10 +284,31 @@ object StateSchemaCompatibilityChecker { if (storeConf.stateSchemaCheckEnabled && result.isDefined) { throw result.get } - val schemaFileLocation = newSchemaFilePath match { - case Some(path) => path.toString - case None => checker.schemaFileLocation.toString + val schemaFileLocation = if (evolvedSchema) { + // if we are using the state schema v3, and we have + // evolved schema, this newSchemaFilePath should be defined + // and we want to populate the metadata with this file + if (stateSchemaVersion == SCHEMA_FORMAT_V3) { + newSchemaFilePath.get.toString + } else { + // if we are using any version less than v3, we have written + // the schema to this static location, which we will return + checker.schemaFileLocation.toString + } + } else { + // if we have not evolved schema (there has been a previous schema) + // and we are using state schema v3, this file path would be defined + // so we would just populate the next run's metadata file with this + // file path + if (stateSchemaVersion == SCHEMA_FORMAT_V3) { + oldSchemaFilePath.get.toString + } else { + // if we are using any version less than v3, we have written + // the schema to this static location, which we will return + checker.schemaFileLocation.toString + } } + StateSchemaValidationResult(evolvedSchema, schemaFileLocation) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index c092adb354c2d..3cb41710a22c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -166,17 +166,19 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp "number of state store instances") ) ++ stateStoreCustomMetrics ++ pythonMetrics - def stateSchemaFilePath(storeName: Option[String] = None): Path = { - def stateInfo = getStateInfo + // This method is only used to fetch the state schema directory path for + // operators that use StateSchemaV3, as prior versions only use a single + // set file path. + def stateSchemaDirPath( + storeName: Option[String] = None): Path = { val stateCheckpointPath = new Path(getStateInfo.checkpointLocation, - s"${stateInfo.operatorId.toString}") + s"${getStateInfo.operatorId.toString}") storeName match { case Some(storeName) => - val storeNamePath = new Path(stateCheckpointPath, storeName) - new Path(new Path(storeNamePath, "_metadata"), "schema") + new Path(new Path(stateCheckpointPath, "_stateSchema"), storeName) case None => - new Path(new Path(stateCheckpointPath, "_metadata"), "schema") + new Path(new Path(stateCheckpointPath, "_stateSchema"), "default") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala index fd3df372a2d56..192b5bf65c4c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.expressions import org.apache.spark.SparkException import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveBooleanEncoder, ProductEncoder} /** * An aggregator that uses a single associative and commutative reduce function. 
This reduce @@ -46,10 +47,10 @@ private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T) override def zero: (Boolean, T) = (false, _zero.asInstanceOf[T]) - override def bufferEncoder: Encoder[(Boolean, T)] = - ExpressionEncoder.tuple( - ExpressionEncoder[Boolean](), - encoder.asInstanceOf[ExpressionEncoder[T]]) + override def bufferEncoder: Encoder[(Boolean, T)] = { + ProductEncoder.tuple(Seq(PrimitiveBooleanEncoder, encoder.asInstanceOf[AgnosticEncoder[T]])) + .asInstanceOf[Encoder[(Boolean, T)]] + } override def outputEncoder: Encoder[T] = encoder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 99287bddb5104..0d0258f11efb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.internal import org.apache.spark.annotation.Unstable import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} import org.apache.spark.sql.artifact.ArtifactManager -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, ReplaceCharWithVarchar, ResolveSessionCatalog, TableFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -205,6 +205,8 @@ abstract class BaseSessionStateBuilder( new ResolveSessionCatalog(this.catalogManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: + new ResolveTranspose(session) +: + new InvokeProcedures(session) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 8b4fa90b3119d..52b8d35e2fbf8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, LocalRelation, LogicalPlan, OptionList, RecoverPartitions, ShowFunctions, ShowNamespaces, ShowTables, UnresolvedTableSpec, View} import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{CatalogManager, SupportsNamespaces, TableCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{CatalogHelper, MultipartIdentifierHelper, NamespaceHelper, TransformHelper} import org.apache.spark.sql.errors.QueryCompilationErrors @@ -671,12 +672,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } else { CatalogTableType.MANAGED } - val location = if (storage.locationUri.isDefined) { - val locationStr = storage.locationUri.get.toString - Some(locationStr) - } else { - None - } + + // The location in 
UnresolvedTableSpec should be the original user-provided path string. + val location = CaseInsensitiveMap(options).get("path") val newOptions = OptionList(options.map { case (key, value) => (key, Literal(value).asInstanceOf[Expression]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala index 7248a2d3f056e..f0eef9ae1cbb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala @@ -426,8 +426,10 @@ final class DataFrameWriterImpl[T] private[sql](ds: Dataset[T]) extends DataFram import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ val session = df.sparkSession - val canUseV2 = lookupV2Provider().isDefined || - df.sparkSession.sessionState.conf.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined + val canUseV2 = lookupV2Provider().isDefined || (df.sparkSession.sessionState.conf.getConf( + SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined && + !df.sparkSession.sessionState.catalogManager.catalog(CatalogManager.SESSION_CATALOG_NAME) + .isInstanceOf[DelegatingCatalogExtension]) session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match { case nameParts @ NonSessionCatalogAndIdentifier(catalog, ident) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala new file mode 100644 index 0000000000000..0a19e6c47afa9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util + +import scala.collection.mutable +import scala.jdk.CollectionConverters.MapHasAsScala + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.{Column, DataFrame, DataFrameWriterV2, Dataset} +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedFunction, UnresolvedIdentifier, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.connector.catalog.TableWritePrivilege._ +import org.apache.spark.sql.connector.expressions._ +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.types.IntegerType + +/** + * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 API. 
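
[Editor's note] The CatalogImpl change earlier in this hunk takes the table location from the user-supplied options rather than the resolved storage URI, and the lookup is case-insensitive. A tiny sketch of that lookup in plain Scala, standing in for CaseInsensitiveMap:

import java.util.Locale

// Resolve the user-provided "path" option regardless of key casing, e.g. "PATH" or "Path".
def userProvidedLocation(options: Map[String, String]): Option[String] =
  options.collectFirst { case (k, v) if k.toLowerCase(Locale.ROOT) == "path" => v }
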
+ * + * @since 3.0.0 + */ +@Experimental +final class DataFrameWriterV2Impl[T] private[sql](table: String, ds: Dataset[T]) + extends DataFrameWriterV2[T] { + + private val df: DataFrame = ds.toDF() + + private val sparkSession = ds.sparkSession + import sparkSession.expression + + private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table) + + private val logicalPlan = df.queryExecution.logical + + private var provider: Option[String] = None + + private val options = new mutable.HashMap[String, String]() + + private val properties = new mutable.HashMap[String, String]() + + private var partitioning: Option[Seq[Transform]] = None + + private var clustering: Option[ClusterByTransform] = None + + /** @inheritdoc */ + override def using(provider: String): this.type = { + this.provider = Some(provider) + this + } + + /** @inheritdoc */ + override def option(key: String, value: String): this.type = { + this.options.put(key, value) + this + } + + /** @inheritdoc */ + override def options(options: scala.collection.Map[String, String]): this.type = { + options.foreach { + case (key, value) => + this.options.put(key, value) + } + this + } + + /** @inheritdoc */ + override def options(options: util.Map[String, String]): this.type = { + this.options(options.asScala) + this + } + + /** @inheritdoc */ + override def tableProperty(property: String, value: String): this.type = { + this.properties.put(property, value) + this + } + + + /** @inheritdoc */ + @scala.annotation.varargs + override def partitionedBy(column: Column, columns: Column*): this.type = { + def ref(name: String): NamedReference = LogicalExpressions.parseReference(name) + + val asTransforms = (column +: columns).map(expression).map { + case PartitionTransform.YEARS(Seq(attr: Attribute)) => + LogicalExpressions.years(ref(attr.name)) + case PartitionTransform.MONTHS(Seq(attr: Attribute)) => + LogicalExpressions.months(ref(attr.name)) + case PartitionTransform.DAYS(Seq(attr: Attribute)) => + LogicalExpressions.days(ref(attr.name)) + case PartitionTransform.HOURS(Seq(attr: Attribute)) => + LogicalExpressions.hours(ref(attr.name)) + case PartitionTransform.BUCKET(Seq(Literal(numBuckets: Int, IntegerType), attr: Attribute)) => + LogicalExpressions.bucket(numBuckets, Array(ref(attr.name))) + case PartitionTransform.BUCKET(Seq(numBuckets, e)) => + throw QueryCompilationErrors.invalidBucketsNumberError(numBuckets.toString, e.toString) + case attr: Attribute => + LogicalExpressions.identity(ref(attr.name)) + case expr => + throw QueryCompilationErrors.invalidPartitionTransformationError(expr) + } + + this.partitioning = Some(asTransforms) + validatePartitioning() + this + } + + /** @inheritdoc */ + @scala.annotation.varargs + override def clusterBy(colName: String, colNames: String*): this.type = { + this.clustering = + Some(ClusterByTransform((colName +: colNames).map(col => FieldReference(col)))) + validatePartitioning() + this + } + + /** + * Validate that clusterBy is not used with partitionBy. 
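
[Editor's note] For context, a usage sketch of the v2 writer path this class implements: partitionedBy and clusterBy are mutually exclusive, which is what validatePartitioning() enforces below. The sketch assumes an existing DataFrame `df` with the referenced columns and a configured v2 catalog named `cat`:

import org.apache.spark.sql.functions.col

// Partitioned create: identity/days/bucket transforms are derived from the column expressions.
df.writeTo("cat.db.events")
  .using("parquet")
  .partitionedBy(col("event_date"))
  .create()

// Clustered create: allowed on its own, but combining clusterBy with partitionedBy
// is rejected by validatePartitioning() with a compilation error.
df.writeTo("cat.db.events_clustered")
  .using("parquet")
  .clusterBy("user_id")
  .create()
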
+ */ + private def validatePartitioning(): Unit = { + if (partitioning.nonEmpty && clustering.nonEmpty) { + throw QueryCompilationErrors.clusterByWithPartitionedBy() + } + } + + /** @inheritdoc */ + override def create(): Unit = { + val tableSpec = UnresolvedTableSpec( + properties = properties.toMap, + provider = provider, + optionExpression = OptionList(Seq.empty), + location = None, + comment = None, + serde = None, + external = false) + runCommand( + CreateTableAsSelect( + UnresolvedIdentifier(tableName), + partitioning.getOrElse(Seq.empty) ++ clustering, + logicalPlan, + tableSpec, + options.toMap, + false)) + } + + /** @inheritdoc */ + override def replace(): Unit = { + internalReplace(orCreate = false) + } + + /** @inheritdoc */ + override def createOrReplace(): Unit = { + internalReplace(orCreate = true) + } + + /** @inheritdoc */ + @throws(classOf[NoSuchTableException]) + def append(): Unit = { + val append = AppendData.byName( + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)), + logicalPlan, options.toMap) + runCommand(append) + } + + /** @inheritdoc */ + @throws(classOf[NoSuchTableException]) + def overwrite(condition: Column): Unit = { + val overwrite = OverwriteByExpression.byName( + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)), + logicalPlan, expression(condition), options.toMap) + runCommand(overwrite) + } + + /** @inheritdoc */ + @throws(classOf[NoSuchTableException]) + def overwritePartitions(): Unit = { + val dynamicOverwrite = OverwritePartitionsDynamic.byName( + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)), + logicalPlan, options.toMap) + runCommand(dynamicOverwrite) + } + + /** + * Wrap an action to track the QueryExecution and time cost, then report to the user-registered + * callback functions. + */ + private def runCommand(command: LogicalPlan): Unit = { + val qe = new QueryExecution(sparkSession, command, df.queryExecution.tracker) + qe.assertCommandExecuted() + } + + private def internalReplace(orCreate: Boolean): Unit = { + val tableSpec = UnresolvedTableSpec( + properties = properties.toMap, + provider = provider, + optionExpression = OptionList(Seq.empty), + location = None, + comment = None, + serde = None, + external = false) + runCommand(ReplaceTableAsSelect( + UnresolvedIdentifier(tableName), + partitioning.getOrElse(Seq.empty) ++ clustering, + logicalPlan, + tableSpec, + writeOptions = options.toMap, + orCreate = orCreate)) + } +} + +private object PartitionTransform { + class ExtractTransform(name: String) { + private val NAMES = Seq(name) + + def unapply(e: Expression): Option[Seq[Expression]] = e match { + case UnresolvedFunction(NAMES, children, false, None, false, Nil, true) => Option(children) + case _ => None + } + } + + val HOURS = new ExtractTransform("hours") + val DAYS = new ExtractTransform("days") + val MONTHS = new ExtractTransform("months") + val YEARS = new ExtractTransform("years") + val BUCKET = new ExtractTransform("bucket") +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala new file mode 100644 index 0000000000000..bb8146e3e0e33 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import scala.collection.mutable + +import org.apache.spark.SparkRuntimeException +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.{Column, DataFrame, Dataset, MergeIntoWriter} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.functions.expr + +/** + * `MergeIntoWriter` provides methods to define and execute merge actions based + * on specified conditions. + * + * @tparam T the type of data in the Dataset. + * @param table the name of the target table for the merge operation. + * @param ds the source Dataset to merge into the target table. + * @param on the merge condition. + * + * @since 4.0.0 + */ +@Experimental +class MergeIntoWriterImpl[T] private[sql] (table: String, ds: Dataset[T], on: Column) + extends MergeIntoWriter[T] { + + private val df: DataFrame = ds.toDF() + + private[sql] val sparkSession = ds.sparkSession + import sparkSession.RichColumn + + private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table) + + private val logicalPlan = df.queryExecution.logical + + private[sql] val matchedActions = mutable.Buffer.empty[MergeAction] + private[sql] val notMatchedActions = mutable.Buffer.empty[MergeAction] + private[sql] val notMatchedBySourceActions = mutable.Buffer.empty[MergeAction] + + /** @inheritdoc */ + def merge(): Unit = { + if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) { + throw new SparkRuntimeException( + errorClass = "NO_MERGE_ACTION_SPECIFIED", + messageParameters = Map.empty) + } + + val merge = MergeIntoTable( + UnresolvedRelation(tableName).requireWritePrivileges(MergeIntoTable.getWritePrivileges( + matchedActions, notMatchedActions, notMatchedBySourceActions)), + logicalPlan, + on.expr, + matchedActions.toSeq, + notMatchedActions.toSeq, + notMatchedBySourceActions.toSeq, + schemaEvolutionEnabled) + val qe = sparkSession.sessionState.executePlan(merge) + qe.assertCommandExecuted() + } + + override protected[sql] def insertAll(condition: Option[Column]): MergeIntoWriter[T] = { + this.notMatchedActions += InsertStarAction(condition.map(_.expr)) + this + } + + override protected[sql] def insert( + condition: Option[Column], + map: Map[String, Column]): MergeIntoWriter[T] = { + this.notMatchedActions += InsertAction(condition.map(_.expr), mapToAssignments(map)) + this + } + + override protected[sql] def updateAll( + condition: Option[Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + appendUpdateDeleteAction(UpdateStarAction(condition.map(_.expr)), notMatchedBySource) + } + + override protected[sql] def update( + condition: Option[Column], + map: Map[String, Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] = { 
+ appendUpdateDeleteAction( + UpdateAction(condition.map(_.expr), mapToAssignments(map)), + notMatchedBySource) + } + + override protected[sql] def delete( + condition: Option[Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + appendUpdateDeleteAction(DeleteAction(condition.map(_.expr)), notMatchedBySource) + } + + private def appendUpdateDeleteAction( + action: MergeAction, + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + if (notMatchedBySource) { + notMatchedBySourceActions += action + } else { + matchedActions += action + } + this + } + + private def mapToAssignments(map: Map[String, Column]): Seq[Assignment] = { + map.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala similarity index 51% rename from sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala rename to sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala index ed8cf4f121f03..ca439cdb89958 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala @@ -15,15 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.internal import scala.jdk.CollectionConverters._ import org.apache.spark.SPARK_DOC_ROOT import org.apache.spark.annotation.Stable -import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry} +import org.apache.spark.internal.config.ConfigEntry +import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.SQLConf /** * Runtime configuration interface for Spark. To access this, use `SparkSession.conf`. @@ -33,89 +33,26 @@ import org.apache.spark.sql.internal.SQLConf * @since 2.0.0 */ @Stable -class RuntimeConfig private[sql](val sqlConf: SQLConf = new SQLConf) { +class RuntimeConfigImpl private[sql](val sqlConf: SQLConf = new SQLConf) extends RuntimeConfig { - /** - * Sets the given Spark runtime configuration property. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def set(key: String, value: String): Unit = { requireNonStaticConf(key) sqlConf.setConfString(key, value) } - /** - * Sets the given Spark runtime configuration property. - * - * @since 2.0.0 - */ - def set(key: String, value: Boolean): Unit = { - set(key, value.toString) - } - - /** - * Sets the given Spark runtime configuration property. - * - * @since 2.0.0 - */ - def set(key: String, value: Long): Unit = { - set(key, value.toString) - } - - /** - * Sets the given Spark runtime configuration property. - */ - private[sql] def set[T](entry: ConfigEntry[T], value: T): Unit = { - requireNonStaticConf(entry.key) - sqlConf.setConf(entry, value) - } - - /** - * Returns the value of Spark runtime configuration property for the given key. - * - * @throws java.util.NoSuchElementException if the key is not set and does not have a default - * value - * @since 2.0.0 - */ + /** @inheritdoc */ @throws[NoSuchElementException]("if the key is not set") def get(key: String): String = { sqlConf.getConfString(key) } - /** - * Returns the value of Spark runtime configuration property for the given key. 
- * - * @since 2.0.0 - */ + /** @inheritdoc */ def get(key: String, default: String): String = { sqlConf.getConfString(key, default) } - /** - * Returns the value of Spark runtime configuration property for the given key. - */ - @throws[NoSuchElementException]("if the key is not set") - private[sql] def get[T](entry: ConfigEntry[T]): T = { - sqlConf.getConf(entry) - } - - private[sql] def get[T](entry: OptionalConfigEntry[T]): Option[T] = { - sqlConf.getConf(entry) - } - - /** - * Returns the value of Spark runtime configuration property for the given key. - */ - private[sql] def get[T](entry: ConfigEntry[T], default: T): T = { - sqlConf.getConf(entry, default) - } - - /** - * Returns all properties set in this conf. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def getAll: Map[String, String] = { sqlConf.getAllConfs } @@ -124,36 +61,20 @@ class RuntimeConfig private[sql](val sqlConf: SQLConf = new SQLConf) { getAll.asJava } - /** - * Returns the value of Spark runtime configuration property for the given key. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def getOption(key: String): Option[String] = { try Option(get(key)) catch { case _: NoSuchElementException => None } } - /** - * Resets the configuration property for the given key. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def unset(key: String): Unit = { requireNonStaticConf(key) sqlConf.unsetConf(key) } - /** - * Indicates whether the configuration property with the given key - * is modifiable in the current session. - * - * @return `true` if the configuration property is modifiable. For static SQL, Spark Core, - * invalid (not existing) and other non-modifiable configuration properties, - * the returned value is `false`. - * @since 2.4.0 - */ + /** @inheritdoc */ def isModifiable(key: String): Boolean = sqlConf.isModifiable(key) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala index b6340a35e7703..23ceb8135fa8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.internal +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression @@ -25,10 +27,10 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression private[sql] object TypedAggUtils { def aggKeyColumn[A]( - encoder: ExpressionEncoder[A], + encoder: Encoder[A], groupingAttributes: Seq[Attribute]): NamedExpression = { - if (!encoder.isSerializedAsStructForTopLevel) { - assert(groupingAttributes.length == 1) + val agnosticEncoder = agnosticEncoderFor(encoder) + if (!agnosticEncoder.isStruct) { if (SQLConf.get.nameNonStructGroupingKeyAsValue) { groupingAttributes.head } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index f2b626490d13c..785bf5b13aa78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -46,12 +46,33 @@ private case class MySQLDialect() 
extends JdbcDialect with SQLConfHelper with No // See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions - private val supportedFunctions = supportedAggregateFunctions + private val supportedFunctions = supportedAggregateFunctions ++ Set("DATE_ADD", "DATE_DIFF") override def isSupportedFunction(funcName: String): Boolean = supportedFunctions.contains(funcName) class MySQLSQLBuilder extends JDBCSQLBuilder { + override def visitExtract(field: String, source: String): String = { + field match { + case "DAY_OF_YEAR" => s"DAYOFYEAR($source)" + case "YEAR_OF_WEEK" => s"EXTRACT(YEAR FROM $source)" + // WEEKDAY uses Monday = 0, Tuesday = 1, ... and ISO standard is Monday = 1, ..., + // so we use the formula (WEEKDAY + 1) to follow the ISO standard. + case "DAY_OF_WEEK" => s"(WEEKDAY($source) + 1)" + case _ => super.visitExtract(field, source) + } + } + + override def visitSQLFunction(funcName: String, inputs: Array[String]): String = { + funcName match { + case "DATE_ADD" => + s"DATE_ADD(${inputs(0)}, INTERVAL ${inputs(1)} DAY)" + case "DATE_DIFF" => + s"DATEDIFF(${inputs(0)}, ${inputs(1)})" + case _ => super.visitSQLFunction(funcName, inputs) + } + } + override def visitSortOrder( sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = { (sortDirection, nullOrdering) match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 7085366c3b7a3..af9fd5464277c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -140,13 +140,20 @@ class SingleStatementExec( * Implements recursive iterator logic over all child execution nodes. * @param collection * Collection of child execution nodes. + * @param label + * Label set by user or None otherwise. */ -abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundStatementExec]) +abstract class CompoundNestedStatementIteratorExec( + collection: Seq[CompoundStatementExec], + label: Option[String] = None) extends NonLeafStatementExec { private var localIterator = collection.iterator private var curr = if (localIterator.hasNext) Some(localIterator.next()) else None + /** Used to stop the iteration in cases when LEAVE statement is encountered. 
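
[Editor's note] The MySQL dialect additions above push DATE_ADD/DATE_DIFF and a few EXTRACT fields down as native MySQL syntax. A sketch of the string rendering involved, where the inputs are already-compiled SQL fragments (the real builder delegates unknown names to its parent instead of the generic fallback shown here):

// Render the pushed-down functions the way the MySQL builder does. DAY_OF_WEEK is shifted
// because MySQL's WEEKDAY() is 0-based with Monday = 0, while the ISO convention is Monday = 1.
def renderMySqlFunction(name: String, inputs: Array[String]): String = name match {
  case "DATE_ADD"  => s"DATE_ADD(${inputs(0)}, INTERVAL ${inputs(1)} DAY)"
  case "DATE_DIFF" => s"DATEDIFF(${inputs(0)}, ${inputs(1)})"
  case other       => s"$other(${inputs.mkString(", ")})"
}

def renderMySqlDayOfWeek(source: String): String = s"(WEEKDAY($source) + 1)"

// renderMySqlFunction("DATE_DIFF", Array("end_date", "start_date"))  ==>  "DATEDIFF(end_date, start_date)"
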
*/ + private var stopIteration = false + private lazy val treeIterator: Iterator[CompoundStatementExec] = new Iterator[CompoundStatementExec] { override def hasNext: Boolean = { @@ -157,7 +164,7 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState case _ => throw SparkException.internalError( "Unknown statement type encountered during SQL script interpretation.") } - localIterator.hasNext || childHasNext + !stopIteration && (localIterator.hasNext || childHasNext) } @scala.annotation.tailrec @@ -165,12 +172,28 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState curr match { case None => throw SparkException.internalError( "No more elements to iterate through in the current SQL compound statement.") + case Some(leaveStatement: LeaveStatementExec) => + handleLeaveStatement(leaveStatement) + curr = None + leaveStatement + case Some(iterateStatement: IterateStatementExec) => + handleIterateStatement(iterateStatement) + curr = None + iterateStatement case Some(statement: LeafStatementExec) => curr = if (localIterator.hasNext) Some(localIterator.next()) else None statement case Some(body: NonLeafStatementExec) => if (body.getTreeIterator.hasNext) { - body.getTreeIterator.next() + body.getTreeIterator.next() match { + case leaveStatement: LeaveStatementExec => + handleLeaveStatement(leaveStatement) + leaveStatement + case iterateStatement: IterateStatementExec => + handleIterateStatement(iterateStatement) + iterateStatement + case other => other + } } else { curr = if (localIterator.hasNext) Some(localIterator.next()) else None next() @@ -187,6 +210,37 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState collection.foreach(_.reset()) localIterator = collection.iterator curr = if (localIterator.hasNext) Some(localIterator.next()) else None + stopIteration = false + } + + /** Actions to do when LEAVE statement is encountered, to stop the execution of this compound. */ + private def handleLeaveStatement(leaveStatement: LeaveStatementExec): Unit = { + if (!leaveStatement.hasBeenMatched) { + // Stop the iteration. + stopIteration = true + + // TODO: Variable cleanup (once we add SQL script execution logic). + // TODO: Add interpreter tests as well. + + // Check if label has been matched. + leaveStatement.hasBeenMatched = label.isDefined && label.get.equals(leaveStatement.label) + } + } + + /** + * Actions to do when ITERATE statement is encountered, to stop the execution of this compound. + */ + private def handleIterateStatement(iterateStatement: IterateStatementExec): Unit = { + if (!iterateStatement.hasBeenMatched) { + // Stop the iteration. + stopIteration = true + + // TODO: Variable cleanup (once we add SQL script execution logic). + // TODO: Add interpreter tests as well. + + // No need to check if label has been matched, since ITERATE statement is already + // not allowed in CompoundBody. + } } } @@ -194,9 +248,11 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState * Executable node for CompoundBody. * @param statements * Executable nodes for nested statements within the CompoundBody. + * @param label + * Label set by user to CompoundBody or None otherwise. 
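To make the new compound iteration logic easier to follow, here is a minimal, self-contained sketch of the idea behind handleLeaveStatement: an unmatched LEAVE stops iteration in every enclosing scope it bubbles through, and is marked as matched only by the scope whose label equals the statement's label. This is not part of the patch; Scope and Leave are hypothetical stand-ins for the real executable nodes.

// Hypothetical, simplified model of LEAVE propagation through labeled scopes.
object LeavePropagationSketch {
  final case class Leave(label: String, var hasBeenMatched: Boolean = false)

  final class Scope(val label: Option[String]) {
    var stopIteration = false

    // Mirrors handleLeaveStatement: stop this scope's iteration and mark the
    // LEAVE as matched only if this scope carries the targeted label.
    def handleLeave(leave: Leave): Unit = {
      if (!leave.hasBeenMatched) {
        stopIteration = true
        leave.hasBeenMatched = label.contains(leave.label)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val leave = Leave("lbl_outer")
    val inner = new Scope(Some("lbl_inner"))
    val outer = new Scope(Some("lbl_outer"))

    inner.handleLeave(leave) // stops inner; label does not match yet
    outer.handleLeave(leave) // stops outer and matches the label
    println((inner.stopIteration, outer.stopIteration, leave.hasBeenMatched)) // (true,true,true)
  }
}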
*/ -class CompoundBodyExec(statements: Seq[CompoundStatementExec]) - extends CompoundNestedStatementIteratorExec(statements) +class CompoundBodyExec(statements: Seq[CompoundStatementExec], label: Option[String] = None) + extends CompoundNestedStatementIteratorExec(statements, label) /** * Executable node for IfElseStatement. @@ -277,11 +333,13 @@ class IfElseStatementExec( * Executable node for WhileStatement. * @param condition Executable node for the condition. * @param body Executable node for the body. + * @param label Label set to WhileStatement by user or None otherwise. * @param session Spark session that SQL script is executed within. */ class WhileStatementExec( condition: SingleStatementExec, body: CompoundBodyExec, + label: Option[String], session: SparkSession) extends NonLeafStatementExec { private object WhileState extends Enumeration { @@ -308,6 +366,26 @@ class WhileStatementExec( condition case WhileState.Body => val retStmt = body.getTreeIterator.next() + + // Handle LEAVE or ITERATE statement if it has been encountered. + retStmt match { + case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched => + if (label.contains(leaveStatementExec.label)) { + leaveStatementExec.hasBeenMatched = true + } + curr = None + return retStmt + case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched => + if (label.contains(iterStatementExec.label)) { + iterStatementExec.hasBeenMatched = true + } + state = WhileState.Condition + curr = Some(condition) + condition.reset() + return retStmt + case _ => + } + if (!body.getTreeIterator.hasNext) { state = WhileState.Condition curr = Some(condition) @@ -326,3 +404,191 @@ class WhileStatementExec( body.reset() } } + +/** + * Executable node for CaseStatement. + * @param conditions Collection of executable conditions which correspond to WHEN clauses. + * @param conditionalBodies Collection of executable bodies that have a corresponding condition, + * in WHEN branches. + * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch. + * @param session Spark session that SQL script is executed within. + */ +class CaseStatementExec( + conditions: Seq[SingleStatementExec], + conditionalBodies: Seq[CompoundBodyExec], + elseBody: Option[CompoundBodyExec], + session: SparkSession) extends NonLeafStatementExec { + private object CaseState extends Enumeration { + val Condition, Body = Value + } + + private var state = CaseState.Condition + private var curr: Option[CompoundStatementExec] = Some(conditions.head) + + private var clauseIdx: Int = 0 + private val conditionsCount = conditions.length + + private lazy val treeIterator: Iterator[CompoundStatementExec] = + new Iterator[CompoundStatementExec] { + override def hasNext: Boolean = curr.nonEmpty + + override def next(): CompoundStatementExec = state match { + case CaseState.Condition => + val condition = curr.get.asInstanceOf[SingleStatementExec] + if (evaluateBooleanCondition(session, condition)) { + state = CaseState.Body + curr = Some(conditionalBodies(clauseIdx)) + } else { + clauseIdx += 1 + if (clauseIdx < conditionsCount) { + // There are WHEN clauses remaining. + state = CaseState.Condition + curr = Some(conditions(clauseIdx)) + } else if (elseBody.isDefined) { + // ELSE clause exists. + state = CaseState.Body + curr = Some(elseBody.get) + } else { + // No remaining clauses. 
+ curr = None + } + } + condition + case CaseState.Body => + assert(curr.get.isInstanceOf[CompoundBodyExec]) + val currBody = curr.get.asInstanceOf[CompoundBodyExec] + val retStmt = currBody.getTreeIterator.next() + if (!currBody.getTreeIterator.hasNext) { + curr = None + } + retStmt + } + } + + override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator + + override def reset(): Unit = { + state = CaseState.Condition + curr = Some(conditions.head) + clauseIdx = 0 + conditions.foreach(c => c.reset()) + conditionalBodies.foreach(b => b.reset()) + elseBody.foreach(b => b.reset()) + } +} + +/** + * Executable node for RepeatStatement. + * @param condition Executable node for the condition - evaluates to a row with a single boolean + * expression, otherwise throws an exception. + * @param body Executable node for the body. + * @param label Label set to RepeatStatement by user, or None if not set. + * @param session Spark session that SQL script is executed within. + */ +class RepeatStatementExec( + condition: SingleStatementExec, + body: CompoundBodyExec, + label: Option[String], + session: SparkSession) extends NonLeafStatementExec { + + private object RepeatState extends Enumeration { + val Condition, Body = Value + } + + private var state = RepeatState.Body + private var curr: Option[CompoundStatementExec] = Some(body) + + private lazy val treeIterator: Iterator[CompoundStatementExec] = + new Iterator[CompoundStatementExec] { + override def hasNext: Boolean = curr.nonEmpty + + override def next(): CompoundStatementExec = state match { + case RepeatState.Condition => + val condition = curr.get.asInstanceOf[SingleStatementExec] + if (!evaluateBooleanCondition(session, condition)) { + state = RepeatState.Body + curr = Some(body) + body.reset() + } else { + curr = None + } + condition + case RepeatState.Body => + val retStmt = body.getTreeIterator.next() + + retStmt match { + case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched => + if (label.contains(leaveStatementExec.label)) { + leaveStatementExec.hasBeenMatched = true + } + curr = None + return retStmt + case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched => + if (label.contains(iterStatementExec.label)) { + iterStatementExec.hasBeenMatched = true + } + state = RepeatState.Condition + curr = Some(condition) + condition.reset() + return retStmt + case _ => + } + + if (!body.getTreeIterator.hasNext) { + state = RepeatState.Condition + curr = Some(condition) + condition.reset() + } + retStmt + } + } + + override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator + + override def reset(): Unit = { + state = RepeatState.Body + curr = Some(body) + body.reset() + condition.reset() + } +} + +/** + * Executable node for LeaveStatement. + * @param label Label of the compound or loop to leave. + */ +class LeaveStatementExec(val label: String) extends LeafStatementExec { + /** + * Label specified in the LEAVE statement might not belong to the immediate surrounding compound, + * but to any surrounding compound. + * Iteration logic is recursive, i.e. when iterating through the compound, if another + * compound is encountered, next() will be called to iterate its body. The same logic + * is applied to any other compound down the traversal tree. + * In such cases, when a LEAVE statement is encountered (as the leaf of the traversal tree), + * it will be propagated upwards and the logic will try to match it to the labels of + * surrounding compounds.
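The loop executors above treat the two jump statements differently: a LEAVE whose label matches the loop ends it (curr = None), while a matching ITERATE skips the rest of the body and jumps back to the condition; an unmatched statement is returned so it can bubble up to an enclosing scope. A rough sketch of that dispatch, with hypothetical names and plain strings standing in for the real exec nodes (in the actual code the loop also updates its own state before propagating):

// Hypothetical sketch of how a labeled loop reacts to LEAVE vs ITERATE.
object LoopJumpSketch {
  sealed trait Signal
  case class Leave(label: String) extends Signal
  case class Iterate(label: String) extends Signal

  // Returns the action a loop labeled `loopLabel` would take for a signal
  // emitted from its body, mirroring WhileStatementExec/RepeatStatementExec.
  def react(loopLabel: Option[String], signal: Signal): String = signal match {
    case Leave(l) if loopLabel.contains(l)   => "stop the loop"                   // curr = None
    case Iterate(l) if loopLabel.contains(l) => "jump back to the loop condition" // state = Condition
    case _                                   => "return the statement so it bubbles up unmatched"
  }

  def main(args: Array[String]): Unit = {
    println(react(Some("lbl"), Leave("lbl")))   // stop the loop
    println(react(Some("lbl"), Iterate("lbl"))) // jump back to the loop condition
    println(react(Some("lbl"), Leave("outer"))) // return the statement so it bubbles up unmatched
  }
}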
+ * Once the match is found, this flag is set to true to indicate that the search should be stopped. + */ + var hasBeenMatched: Boolean = false + override def reset(): Unit = hasBeenMatched = false +} + +/** + * Executable node for ITERATE statement. + * @param label Label of the loop to iterate. + */ +class IterateStatementExec(val label: String) extends LeafStatementExec { + /** + * Label specified in the ITERATE statement might not belong to the immediate compound, + * but to any surrounding compound. + * Iteration logic is recursive, i.e. when iterating through the compound, if another + * compound is encountered, next() will be called to iterate its body. The same logic + * is applied to any other compound down the tree. + * In such cases, when an ITERATE statement is encountered (as the leaf of the traversal tree), + * it will be propagated upwards and the logic will try to match it to the labels of + * surrounding compounds. + * Once the match is found, this flag is set to true to indicate that the search should be stopped. + */ + var hasBeenMatched: Boolean = false + override def reset(): Unit = hasBeenMatched = false +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 08b4f97286280..917b4d6f45ee0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier -import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, IfElseStatement, SingleStatement, WhileStatement} +import org.apache.spark.sql.catalyst.parser.{CaseStatement, CompoundBody, CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement, RepeatStatement, SingleStatement, WhileStatement} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.Origin @@ -71,9 +71,9 @@ case class SqlScriptingInterpreter() { private def transformTreeIntoExecutable( node: CompoundPlanStatement, session: SparkSession): CompoundStatementExec = node match { - case body: CompoundBody => + case CompoundBody(collection, label) => // TODO [SPARK-48530]: Current logic doesn't support scoped variables and shadowing.
- val variables = body.collection.flatMap { + val variables = collection.flatMap { case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan) case _ => None } @@ -82,7 +82,9 @@ case class SqlScriptingInterpreter() { .map(new SingleStatementExec(_, Origin(), isInternal = true)) .reverse new CompoundBodyExec( - body.collection.map(st => transformTreeIntoExecutable(st, session)) ++ dropVariables) + collection.map(st => transformTreeIntoExecutable(st, session)) ++ dropVariables, + label) + case IfElseStatement(conditions, conditionalBodies, elseBody) => val conditionsExec = conditions.map(condition => new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false)) @@ -92,12 +94,38 @@ case class SqlScriptingInterpreter() { transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec]) new IfElseStatementExec( conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session) - case WhileStatement(condition, body, _) => + + case CaseStatement(conditions, conditionalBodies, elseBody) => + val conditionsExec = conditions.map(condition => + // todo: what to put here for isInternal, in case of simple case statement + new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false)) + val conditionalBodiesExec = conditionalBodies.map(body => + transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec]) + val unconditionalBodiesExec = elseBody.map(body => + transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec]) + new CaseStatementExec( + conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session) + + case WhileStatement(condition, body, label) => val conditionExec = new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false) val bodyExec = transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] - new WhileStatementExec(conditionExec, bodyExec, session) + new WhileStatementExec(conditionExec, bodyExec, label, session) + + case RepeatStatement(condition, body, label) => + val conditionExec = + new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false) + val bodyExec = + transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] + new RepeatStatementExec(conditionExec, bodyExec, label, session) + + case leaveStatement: LeaveStatement => + new LeaveStatementExec(leaveStatement.label) + + case iterateStatement: IterateStatement => + new IterateStatementExec(iterateStatement.label) + case sparkStatement: SingleStatement => new SingleStatementExec( sparkStatement.parsedPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 63d937cb34820..7cf92db59067c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -14,159 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.sql.streaming -import java.util.UUID -import java.util.concurrent.TimeoutException - -import org.apache.spark.annotation.Evolving -import org.apache.spark.sql.SparkSession - -/** - * A handle to a query that is executing continuously in the background as new data arrives. - * All these methods are thread-safe. 
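Zooming out on the SqlScriptingInterpreter changes above: transformTreeIntoExecutable is a recursive pattern match from parsed statements to executable nodes, and the patch threads the optional label through compounds and loops while mapping LEAVE and ITERATE to leaf executables. A toy model of that shape, using hypothetical Ast/Exec case classes rather than the real parser and executor types:

// Hypothetical, stripped-down model of the parsed-tree -> executable-tree mapping.
object InterpreterShapeSketch {
  sealed trait Ast
  case class Body(children: Seq[Ast], label: Option[String]) extends Ast
  case class While(conditionSql: String, body: Body, label: Option[String]) extends Ast
  case class LeaveStmt(label: String) extends Ast

  sealed trait Exec
  case class BodyExec(children: Seq[Exec], label: Option[String]) extends Exec
  case class WhileExec(conditionSql: String, body: BodyExec, label: Option[String]) extends Exec
  case class LeaveExec(label: String) extends Exec

  // Recursively convert the parsed tree, carrying labels down to the exec nodes.
  def transform(node: Ast): Exec = node match {
    case Body(children, label)    => BodyExec(children.map(transform), label)
    case While(cond, body, label) => WhileExec(cond, transform(body).asInstanceOf[BodyExec], label)
    case LeaveStmt(label)         => LeaveExec(label)
  }

  def main(args: Array[String]): Unit = {
    val ast = Body(Seq(While("x < 10", Body(Seq(LeaveStmt("lbl")), None), Some("lbl"))), None)
    println(transform(ast))
  }
}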
- * @since 2.0.0 - */ -@Evolving -trait StreamingQuery { - - /** - * Returns the user-specified name of the query, or null if not specified. - * This name can be specified in the `org.apache.spark.sql.streaming.DataStreamWriter` - * as `dataframe.writeStream.queryName("query").start()`. - * This name, if set, must be unique across all active queries. - * - * @since 2.0.0 - */ - def name: String - - /** - * Returns the unique id of this query that persists across restarts from checkpoint data. - * That is, this id is generated when a query is started for the first time, and - * will be the same every time it is restarted from checkpoint data. Also see [[runId]]. - * - * @since 2.1.0 - */ - def id: UUID - - /** - * Returns the unique id of this run of the query. That is, every start/restart of a query will - * generate a unique runId. Therefore, every time a query is restarted from - * checkpoint, it will have the same [[id]] but different [[runId]]s. - */ - def runId: UUID - - /** - * Returns the `SparkSession` associated with `this`. - * - * @since 2.0.0 - */ - def sparkSession: SparkSession - - /** - * Returns `true` if this query is actively running. - * - * @since 2.0.0 - */ - def isActive: Boolean - - /** - * Returns the [[StreamingQueryException]] if the query was terminated by an exception. - * @since 2.0.0 - */ - def exception: Option[StreamingQueryException] - - /** - * Returns the current status of the query. - * - * @since 2.0.2 - */ - def status: StreamingQueryStatus - - /** - * Returns an array of the most recent [[StreamingQueryProgress]] updates for this query. - * The number of progress updates retained for each stream is configured by Spark session - * configuration `spark.sql.streaming.numRecentProgressUpdates`. - * - * @since 2.1.0 - */ - def recentProgress: Array[StreamingQueryProgress] - - /** - * Returns the most recent [[StreamingQueryProgress]] update of this streaming query. - * - * @since 2.1.0 - */ - def lastProgress: StreamingQueryProgress - - /** - * Waits for the termination of `this` query, either by `query.stop()` or by an exception. - * If the query has terminated with an exception, then the exception will be thrown. - * - * If the query has terminated, then all subsequent calls to this method will either return - * immediately (if the query was terminated by `stop()`), or throw the exception - * immediately (if the query has terminated with exception). - * - * @throws StreamingQueryException if the query has terminated with an exception. - * - * @since 2.0.0 - */ - @throws[StreamingQueryException] - def awaitTermination(): Unit - - /** - * Waits for the termination of `this` query, either by `query.stop()` or by an exception. - * If the query has terminated with an exception, then the exception will be thrown. - * Otherwise, it returns whether the query has terminated or not within the `timeoutMs` - * milliseconds. - * - * If the query has terminated, then all subsequent calls to this method will either return - * `true` immediately (if the query was terminated by `stop()`), or throw the exception - * immediately (if the query has terminated with exception). - * - * @throws StreamingQueryException if the query has terminated with an exception - * - * @since 2.0.0 - */ - @throws[StreamingQueryException] - def awaitTermination(timeoutMs: Long): Boolean - - /** - * Blocks until all available data in the source has been processed and committed to the sink. - * This method is intended for testing. 
Note that in the case of continually arriving data, this - * method may block forever. Additionally, this method is only guaranteed to block until data that - * has been synchronously appended data to a `org.apache.spark.sql.execution.streaming.Source` - * prior to invocation. (i.e. `getOffset` must immediately reflect the addition). - * @since 2.0.0 - */ - def processAllAvailable(): Unit - - /** - * Stops the execution of this query if it is running. This waits until the termination of the - * query execution threads or until a timeout is hit. - * - * By default stop will block indefinitely. You can configure a timeout by the configuration - * `spark.sql.streaming.stopTimeout`. A timeout of 0 (or negative) milliseconds will block - * indefinitely. If a `TimeoutException` is thrown, users can retry stopping the stream. If the - * issue persists, it is advisable to kill the Spark application. - * - * @since 2.0.0 - */ - @throws[TimeoutException] - def stop(): Unit - - /** - * Prints the physical plan to the console for debugging purposes. - * @since 2.0.0 - */ - def explain(): Unit +import org.apache.spark.sql.{api, SparkSession} - /** - * Prints the physical plan to the console for debugging purposes. - * - * @param extended whether to do extended explain or not - * @since 2.0.0 - */ - def explain(extended: Boolean): Unit +/** @inheritdoc */ +trait StreamingQuery extends api.StreamingQuery { + /** @inheritdoc */ + override def sparkSession: SparkSession } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala deleted file mode 100644 index c1ceed048ae2c..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ /dev/null @@ -1,276 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming - -import java.util.UUID - -import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} -import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule} -import org.json4s.{JObject, JString} -import org.json4s.JsonAST.JValue -import org.json4s.JsonDSL.{jobject2assoc, pair2Assoc} -import org.json4s.jackson.JsonMethods.{compact, render} - -import org.apache.spark.annotation.Evolving -import org.apache.spark.scheduler.SparkListenerEvent - -/** - * Interface for listening to events related to [[StreamingQuery StreamingQueries]]. - * @note The methods are not thread-safe as they may be called from different threads. - * - * @since 2.0.0 - */ -@Evolving -abstract class StreamingQueryListener extends Serializable { - - import StreamingQueryListener._ - - /** - * Called when a query is started. 
- * @note This is called synchronously with - * [[org.apache.spark.sql.streaming.DataStreamWriter `DataStreamWriter.start()`]], - * that is, `onQueryStart` will be called on all listeners before - * `DataStreamWriter.start()` returns the corresponding [[StreamingQuery]]. Please - * don't block this method as it will block your query. - * @since 2.0.0 - */ - def onQueryStarted(event: QueryStartedEvent): Unit - - /** - * Called when there is some status update (ingestion rate updated, etc.) - * - * @note This method is asynchronous. The status in [[StreamingQuery]] will always be - * latest no matter when this method is called. Therefore, the status of [[StreamingQuery]] - * may be changed before/when you process the event. E.g., you may find [[StreamingQuery]] - * is terminated when you are processing `QueryProgressEvent`. - * @since 2.0.0 - */ - def onQueryProgress(event: QueryProgressEvent): Unit - - /** - * Called when the query is idle and waiting for new data to process. - * @since 3.5.0 - */ - def onQueryIdle(event: QueryIdleEvent): Unit = {} - - /** - * Called when a query is stopped, with or without error. - * @since 2.0.0 - */ - def onQueryTerminated(event: QueryTerminatedEvent): Unit -} - -/** - * Py4J allows a pure interface so this proxy is required. - */ -private[spark] trait PythonStreamingQueryListener { - import StreamingQueryListener._ - - def onQueryStarted(event: QueryStartedEvent): Unit - - def onQueryProgress(event: QueryProgressEvent): Unit - - def onQueryIdle(event: QueryIdleEvent): Unit - - def onQueryTerminated(event: QueryTerminatedEvent): Unit -} - -private[spark] class PythonStreamingQueryListenerWrapper( - listener: PythonStreamingQueryListener) extends StreamingQueryListener { - import StreamingQueryListener._ - - def onQueryStarted(event: QueryStartedEvent): Unit = listener.onQueryStarted(event) - - def onQueryProgress(event: QueryProgressEvent): Unit = listener.onQueryProgress(event) - - override def onQueryIdle(event: QueryIdleEvent): Unit = listener.onQueryIdle(event) - - def onQueryTerminated(event: QueryTerminatedEvent): Unit = listener.onQueryTerminated(event) -} - -/** - * Companion object of [[StreamingQueryListener]] that defines the listener events. - * @since 2.0.0 - */ -@Evolving -object StreamingQueryListener extends Serializable { - - /** - * Base type of [[StreamingQueryListener]] events - * @since 2.0.0 - */ - @Evolving - trait Event extends SparkListenerEvent - - /** - * Event representing the start of a query - * @param id A unique query id that persists across restarts. See `StreamingQuery.id()`. - * @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`. - * @param name User-specified name of the query, null if not specified. - * @param timestamp The timestamp to start a query. 
- * @since 2.1.0 - */ - @Evolving - class QueryStartedEvent private[sql]( - val id: UUID, - val runId: UUID, - val name: String, - val timestamp: String) extends Event with Serializable { - - def json: String = compact(render(jsonValue)) - - private def jsonValue: JValue = { - ("id" -> JString(id.toString)) ~ - ("runId" -> JString(runId.toString)) ~ - ("name" -> JString(name)) ~ - ("timestamp" -> JString(timestamp)) - } - } - - private[spark] object QueryStartedEvent { - private val mapper = { - val ret = new ObjectMapper() with ClassTagExtensions - ret.registerModule(DefaultScalaModule) - ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) - ret - } - - private[spark] def jsonString(event: QueryStartedEvent): String = - mapper.writeValueAsString(event) - - private[spark] def fromJson(json: String): QueryStartedEvent = - mapper.readValue[QueryStartedEvent](json) - } - - /** - * Event representing any progress updates in a query. - * @param progress The query progress updates. - * @since 2.1.0 - */ - @Evolving - class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event - with Serializable { - - def json: String = compact(render(jsonValue)) - - private def jsonValue: JValue = JObject("progress" -> progress.jsonValue) - } - - private[spark] object QueryProgressEvent { - private val mapper = { - val ret = new ObjectMapper() with ClassTagExtensions - ret.registerModule(DefaultScalaModule) - ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) - ret - } - - private[spark] def jsonString(event: QueryProgressEvent): String = - mapper.writeValueAsString(event) - - private[spark] def fromJson(json: String): QueryProgressEvent = - mapper.readValue[QueryProgressEvent](json) - } - - /** - * Event representing that query is idle and waiting for new data to process. - * - * @param id A unique query id that persists across restarts. See `StreamingQuery.id()`. - * @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`. - * @param timestamp The timestamp when the latest no-batch trigger happened. - * @since 3.5.0 - */ - @Evolving - class QueryIdleEvent private[sql]( - val id: UUID, - val runId: UUID, - val timestamp: String) extends Event with Serializable { - - def json: String = compact(render(jsonValue)) - - private def jsonValue: JValue = { - ("id" -> JString(id.toString)) ~ - ("runId" -> JString(runId.toString)) ~ - ("timestamp" -> JString(timestamp)) - } - } - - private[spark] object QueryIdleEvent { - private val mapper = { - val ret = new ObjectMapper() with ClassTagExtensions - ret.registerModule(DefaultScalaModule) - ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) - ret - } - - private[spark] def jsonString(event: QueryTerminatedEvent): String = - mapper.writeValueAsString(event) - - private[spark] def fromJson(json: String): QueryTerminatedEvent = - mapper.readValue[QueryTerminatedEvent](json) - } - - /** - * Event representing that termination of a query. - * - * @param id A unique query id that persists across restarts. See `StreamingQuery.id()`. - * @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`. - * @param exception The exception message of the query if the query was terminated - * with an exception. Otherwise, it will be `None`. - * @param errorClassOnException The error class from the exception if the query was terminated - * with an exception which is a part of error class framework. 
- * If the query was terminated without an exception, or the - * exception is not a part of error class framework, it will be - * `None`. - * @since 2.1.0 - */ - @Evolving - class QueryTerminatedEvent private[sql]( - val id: UUID, - val runId: UUID, - val exception: Option[String], - val errorClassOnException: Option[String]) extends Event with Serializable { - // compatibility with versions in prior to 3.5.0 - def this(id: UUID, runId: UUID, exception: Option[String]) = { - this(id, runId, exception, None) - } - - def json: String = compact(render(jsonValue)) - - private def jsonValue: JValue = { - ("id" -> JString(id.toString)) ~ - ("runId" -> JString(runId.toString)) ~ - ("exception" -> JString(exception.orNull)) ~ - ("errorClassOnException" -> JString(errorClassOnException.orNull)) - } - } - - private[spark] object QueryTerminatedEvent { - private val mapper = { - val ret = new ObjectMapper() with ClassTagExtensions - ret.registerModule(DefaultScalaModule) - ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) - ret - } - - private[spark] def jsonString(event: QueryTerminatedEvent): String = - mapper.writeValueAsString(event) - - private[spark] def fromJson(json: String): QueryTerminatedEvent = - mapper.readValue[QueryTerminatedEvent](json) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 55d2e639a56b1..3ab6d02f6b515 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -364,7 +364,7 @@ class StreamingQueryManager private[sql] ( .orElse(activeQueries.get(query.id)) // shouldn't be needed but paranoia ... val shouldStopActiveRun = - sparkSession.conf.get(SQLConf.STREAMING_STOP_ACTIVE_RUN_ON_RESTART) + sparkSession.sessionState.conf.getConf(SQLConf.STREAMING_STOP_ACTIVE_RUN_ON_RESTART) if (activeOption.isDefined) { if (shouldStopActiveRun) { val oldQuery = activeOption.get diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala deleted file mode 100644 index fe187917ec021..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.streaming - -import org.json4s._ -import org.json4s.JsonAST.JValue -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods._ - -import org.apache.spark.annotation.Evolving - -/** - * Reports information about the instantaneous status of a streaming query. - * - * @param message A human readable description of what the stream is currently doing. - * @param isDataAvailable True when there is new data to be processed. Doesn't apply - * to ContinuousExecution where it is always false. - * @param isTriggerActive True when the trigger is actively firing, false when waiting for the - * next trigger time. Doesn't apply to ContinuousExecution where it is - * always false. - * - * @since 2.1.0 - */ -@Evolving -class StreamingQueryStatus protected[sql]( - val message: String, - val isDataAvailable: Boolean, - val isTriggerActive: Boolean) extends Serializable { - - /** The compact JSON representation of this status. */ - def json: String = compact(render(jsonValue)) - - /** The pretty (i.e. indented) JSON representation of this status. */ - def prettyJson: String = pretty(render(jsonValue)) - - override def toString: String = prettyJson - - private[sql] def copy( - message: String = this.message, - isDataAvailable: Boolean = this.isDataAvailable, - isTriggerActive: Boolean = this.isTriggerActive): StreamingQueryStatus = { - new StreamingQueryStatus( - message = message, - isDataAvailable = isDataAvailable, - isTriggerActive = isTriggerActive) - } - - private[sql] def jsonValue: JValue = { - ("message" -> JString(message)) ~ - ("isDataAvailable" -> JBool(isDataAvailable)) ~ - ("isTriggerActive" -> JBool(isTriggerActive)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala deleted file mode 100644 index 05323d9d03811..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ /dev/null @@ -1,301 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.streaming - -import java.{util => ju} -import java.lang.{Long => JLong} -import java.util.UUID - -import scala.jdk.CollectionConverters._ -import scala.util.control.NonFatal - -import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} -import com.fasterxml.jackson.databind.annotation.JsonDeserialize -import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule} -import org.json4s._ -import org.json4s.JsonAST.JValue -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods._ - -import org.apache.spark.annotation.Evolving -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema -import org.apache.spark.sql.streaming.SafeJsonSerializer.{safeDoubleToJValue, safeMapToJValue} -import org.apache.spark.sql.streaming.SinkProgress.DEFAULT_NUM_OUTPUT_ROWS - -/** - * Information about updates made to stateful operators in a [[StreamingQuery]] during a trigger. - */ -@Evolving -class StateOperatorProgress private[spark]( - val operatorName: String, - val numRowsTotal: Long, - val numRowsUpdated: Long, - val allUpdatesTimeMs: Long, - val numRowsRemoved: Long, - val allRemovalsTimeMs: Long, - val commitTimeMs: Long, - val memoryUsedBytes: Long, - val numRowsDroppedByWatermark: Long, - val numShufflePartitions: Long, - val numStateStoreInstances: Long, - val customMetrics: ju.Map[String, JLong] = new ju.HashMap() - ) extends Serializable { - - /** The compact JSON representation of this progress. */ - def json: String = compact(render(jsonValue)) - - /** The pretty (i.e. indented) JSON representation of this progress. */ - def prettyJson: String = pretty(render(jsonValue)) - - private[sql] def copy( - newNumRowsUpdated: Long, - newNumRowsDroppedByWatermark: Long): StateOperatorProgress = - new StateOperatorProgress( - operatorName = operatorName, numRowsTotal = numRowsTotal, numRowsUpdated = newNumRowsUpdated, - allUpdatesTimeMs = allUpdatesTimeMs, numRowsRemoved = numRowsRemoved, - allRemovalsTimeMs = allRemovalsTimeMs, commitTimeMs = commitTimeMs, - memoryUsedBytes = memoryUsedBytes, numRowsDroppedByWatermark = newNumRowsDroppedByWatermark, - numShufflePartitions = numShufflePartitions, numStateStoreInstances = numStateStoreInstances, - customMetrics = customMetrics) - - private[sql] def jsonValue: JValue = { - ("operatorName" -> JString(operatorName)) ~ - ("numRowsTotal" -> JInt(numRowsTotal)) ~ - ("numRowsUpdated" -> JInt(numRowsUpdated)) ~ - ("allUpdatesTimeMs" -> JInt(allUpdatesTimeMs)) ~ - ("numRowsRemoved" -> JInt(numRowsRemoved)) ~ - ("allRemovalsTimeMs" -> JInt(allRemovalsTimeMs)) ~ - ("commitTimeMs" -> JInt(commitTimeMs)) ~ - ("memoryUsedBytes" -> JInt(memoryUsedBytes)) ~ - ("numRowsDroppedByWatermark" -> JInt(numRowsDroppedByWatermark)) ~ - ("numShufflePartitions" -> JInt(numShufflePartitions)) ~ - ("numStateStoreInstances" -> JInt(numStateStoreInstances)) ~ - ("customMetrics" -> { - if (!customMetrics.isEmpty) { - val keys = customMetrics.keySet.asScala.toSeq.sorted - keys.map { k => k -> JInt(customMetrics.get(k).toLong) : JObject }.reduce(_ ~ _) - } else { - JNothing - } - }) - } - - override def toString: String = prettyJson -} - -/** - * Information about progress made in the execution of a [[StreamingQuery]] during - * a trigger. Each event relates to processing done for a single trigger of the streaming - * query. Events are emitted even when no new data is available to be processed. - * - * @param id A unique query id that persists across restarts. 
See `StreamingQuery.id()`. - * @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`. - * @param name User-specified name of the query, null if not specified. - * @param timestamp Beginning time of the trigger in ISO8601 format, i.e. UTC timestamps. - * @param batchId A unique id for the current batch of data being processed. Note that in the - * case of retries after a failure a given batchId my be executed more than once. - * Similarly, when there is no data to be processed, the batchId will not be - * incremented. - * @param batchDuration The process duration of each batch. - * @param durationMs The amount of time taken to perform various operations in milliseconds. - * @param eventTime Statistics of event time seen in this batch. It may contain the following keys: - * {{{ - * "max" -> "2016-12-05T20:54:20.827Z" // maximum event time seen in this trigger - * "min" -> "2016-12-05T20:54:20.827Z" // minimum event time seen in this trigger - * "avg" -> "2016-12-05T20:54:20.827Z" // average event time seen in this trigger - * "watermark" -> "2016-12-05T20:54:20.827Z" // watermark used in this trigger - * }}} - * All timestamps are in ISO8601 format, i.e. UTC timestamps. - * @param stateOperators Information about operators in the query that store state. - * @param sources detailed statistics on data being read from each of the streaming sources. - * @since 2.1.0 - */ -@Evolving -class StreamingQueryProgress private[spark]( - val id: UUID, - val runId: UUID, - val name: String, - val timestamp: String, - val batchId: Long, - val batchDuration: Long, - val durationMs: ju.Map[String, JLong], - val eventTime: ju.Map[String, String], - val stateOperators: Array[StateOperatorProgress], - val sources: Array[SourceProgress], - val sink: SinkProgress, - @JsonDeserialize(contentAs = classOf[GenericRowWithSchema]) - val observedMetrics: ju.Map[String, Row]) extends Serializable { - - /** The aggregate (across all sources) number of records processed in a trigger. */ - def numInputRows: Long = sources.map(_.numInputRows).sum - - /** The aggregate (across all sources) rate of data arriving. */ - def inputRowsPerSecond: Double = sources.map(_.inputRowsPerSecond).sum - - /** The aggregate (across all sources) rate at which Spark is processing data. */ - def processedRowsPerSecond: Double = sources.map(_.processedRowsPerSecond).sum - - /** The compact JSON representation of this progress. */ - def json: String = compact(render(jsonValue)) - - /** The pretty (i.e. indented) JSON representation of this progress. 
*/ - def prettyJson: String = pretty(render(jsonValue)) - - override def toString: String = prettyJson - - private[sql] def jsonValue: JValue = { - ("id" -> JString(id.toString)) ~ - ("runId" -> JString(runId.toString)) ~ - ("name" -> JString(name)) ~ - ("timestamp" -> JString(timestamp)) ~ - ("batchId" -> JInt(batchId)) ~ - ("batchDuration" -> JInt(batchDuration)) ~ - ("numInputRows" -> JInt(numInputRows)) ~ - ("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~ - ("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond)) ~ - ("durationMs" -> safeMapToJValue[JLong](durationMs, v => JInt(v.toLong))) ~ - ("eventTime" -> safeMapToJValue[String](eventTime, s => JString(s))) ~ - ("stateOperators" -> JArray(stateOperators.map(_.jsonValue).toList)) ~ - ("sources" -> JArray(sources.map(_.jsonValue).toList)) ~ - ("sink" -> sink.jsonValue) ~ - ("observedMetrics" -> safeMapToJValue[Row](observedMetrics, row => row.jsonValue)) - } -} - -private[spark] object StreamingQueryProgress { - private[this] val mapper = { - val ret = new ObjectMapper() with ClassTagExtensions - ret.registerModule(DefaultScalaModule) - ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) - ret - } - - private[spark] def jsonString(progress: StreamingQueryProgress): String = - mapper.writeValueAsString(progress) - - private[spark] def fromJson(json: String): StreamingQueryProgress = - mapper.readValue[StreamingQueryProgress](json) -} - -/** - * Information about progress made for a source in the execution of a [[StreamingQuery]] - * during a trigger. See [[StreamingQueryProgress]] for more information. - * - * @param description Description of the source. - * @param startOffset The starting offset for data being read. - * @param endOffset The ending offset for data being read. - * @param latestOffset The latest offset from this source. - * @param numInputRows The number of records read from this source. - * @param inputRowsPerSecond The rate at which data is arriving from this source. - * @param processedRowsPerSecond The rate at which data from this source is being processed by - * Spark. - * @since 2.1.0 - */ -@Evolving -class SourceProgress protected[spark]( - val description: String, - val startOffset: String, - val endOffset: String, - val latestOffset: String, - val numInputRows: Long, - val inputRowsPerSecond: Double, - val processedRowsPerSecond: Double, - val metrics: ju.Map[String, String] = Map[String, String]().asJava) extends Serializable { - - /** The compact JSON representation of this progress. */ - def json: String = compact(render(jsonValue)) - - /** The pretty (i.e. indented) JSON representation of this progress. */ - def prettyJson: String = pretty(render(jsonValue)) - - override def toString: String = prettyJson - - private[sql] def jsonValue: JValue = { - ("description" -> JString(description)) ~ - ("startOffset" -> tryParse(startOffset)) ~ - ("endOffset" -> tryParse(endOffset)) ~ - ("latestOffset" -> tryParse(latestOffset)) ~ - ("numInputRows" -> JInt(numInputRows)) ~ - ("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~ - ("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond)) ~ - ("metrics" -> safeMapToJValue[String](metrics, s => JString(s))) - } - - private def tryParse(json: String) = try { - parse(json) - } catch { - case NonFatal(e) => JString(json) - } -} - -/** - * Information about progress made for a sink in the execution of a [[StreamingQuery]] - * during a trigger. 
See [[StreamingQueryProgress]] for more information. - * - * @param description Description of the source corresponding to this status. - * @param numOutputRows Number of rows written to the sink or -1 for Continuous Mode (temporarily) - * or Sink V1 (until decommissioned). - * @since 2.1.0 - */ -@Evolving -class SinkProgress protected[spark]( - val description: String, - val numOutputRows: Long, - val metrics: ju.Map[String, String] = Map[String, String]().asJava) extends Serializable { - - /** SinkProgress without custom metrics. */ - protected[sql] def this(description: String) = { - this(description, DEFAULT_NUM_OUTPUT_ROWS) - } - - /** The compact JSON representation of this progress. */ - def json: String = compact(render(jsonValue)) - - /** The pretty (i.e. indented) JSON representation of this progress. */ - def prettyJson: String = pretty(render(jsonValue)) - - override def toString: String = prettyJson - - private[sql] def jsonValue: JValue = { - ("description" -> JString(description)) ~ - ("numOutputRows" -> JInt(numOutputRows)) ~ - ("metrics" -> safeMapToJValue[String](metrics, s => JString(s))) - } -} - -private[sql] object SinkProgress { - val DEFAULT_NUM_OUTPUT_ROWS: Long = -1L - - def apply(description: String, numOutputRows: Option[Long], - metrics: ju.Map[String, String] = Map[String, String]().asJava): SinkProgress = - new SinkProgress(description, numOutputRows.getOrElse(DEFAULT_NUM_OUTPUT_ROWS), metrics) -} - -private object SafeJsonSerializer { - def safeDoubleToJValue(value: Double): JValue = { - if (value.isNaN || value.isInfinity) JNothing else JDouble(value) - } - - /** Convert map to JValue while handling empty maps. Also, this sorts the keys. */ - def safeMapToJValue[T](map: ju.Map[String, T], valueToJValue: T => JValue): JValue = { - if (map == null || map.isEmpty) return JNothing - val keys = map.asScala.keySet.toSeq.sorted - keys.map { k => k -> valueToJValue(map.get(k)) : JObject }.reduce(_ ~ _) - } -} diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 316e5e9676723..5ad1380e1fb82 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -265,6 +265,7 @@ | org.apache.spark.sql.catalyst.expressions.RaiseErrorExpressionBuilder | raise_error | SELECT raise_error('custom error message') | struct | | org.apache.spark.sql.catalyst.expressions.Rand | rand | SELECT rand() | struct | | org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | struct | +| org.apache.spark.sql.catalyst.expressions.RandStr | randstr | SELECT randstr(3, 0) AS result | struct | | org.apache.spark.sql.catalyst.expressions.Randn | randn | SELECT randn() | struct | | org.apache.spark.sql.catalyst.expressions.Rank | rank | SELECT a, b, rank(b) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct | | org.apache.spark.sql.catalyst.expressions.RegExpCount | regexp_count | SELECT regexp_count('Steven Jones and Stephen Smith are the best players', 'Ste(v|ph)en') | struct | @@ -367,6 +368,7 @@ | org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct | | org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | SELECT positive(1) | struct<(+ 1):int> | | org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | struct | +| 
org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(10, 20, 0) > 0 AS result | struct | | org.apache.spark.sql.catalyst.expressions.UnixDate | unix_date | SELECT unix_date(DATE("1970-01-02")) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMicros | unix_micros | SELECT unix_micros(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMillis | unix_millis | SELECT unix_millis(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | @@ -451,6 +453,7 @@ | org.apache.spark.sql.catalyst.expressions.variant.ParseJsonExpressionBuilder | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct | | org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant | schema_of_variant | SELECT schema_of_variant(parse_json('null')) | struct | | org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariantAgg | schema_of_variant_agg | SELECT schema_of_variant_agg(parse_json(j)) FROM VALUES ('1'), ('2'), ('3') AS tab(j) | struct | +| org.apache.spark.sql.catalyst.expressions.variant.ToVariantObject | to_variant_object | SELECT to_variant_object(named_struct('a', 1, 'b', 2)) | struct | | org.apache.spark.sql.catalyst.expressions.variant.TryParseJsonExpressionBuilder | try_parse_json | SELECT try_parse_json('{"a":1,"b":0.8}') | struct | | org.apache.spark.sql.catalyst.expressions.variant.TryVariantGetExpressionBuilder | try_variant_get | SELECT try_variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct | | org.apache.spark.sql.catalyst.expressions.variant.VariantGetExpressionBuilder | variant_get | SELECT variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct | diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out index 57108c4582f45..53595d1b8a3eb 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out @@ -194,25 +194,8 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query select sort_array(array('b', 'd'), cast(NULL as boolean)) -- !query analysis -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - "sqlState" : "42K09", - "messageParameters" : { - "inputSql" : "\"CAST(NULL AS BOOLEAN)\"", - "inputType" : "\"BOOLEAN\"", - "paramIndex" : "second", - "requiredType" : "\"BOOLEAN\"", - "sqlExpr" : "\"sort_array(array(b, d), CAST(NULL AS BOOLEAN))\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 57, - "fragment" : "sort_array(array('b', 'd'), cast(NULL as boolean))" - } ] -} +Project [sort_array(array(b, d), cast(null as boolean)) AS sort_array(array(b, d), CAST(NULL AS BOOLEAN))#x] ++- OneRowRelation -- !query diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out index fd927b99c6456..0e4d2d4e99e26 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out @@ -736,7 +736,7 @@ Project [to_date(26/October/2015, Some(dd/MMMMM/yyyy), Some(America/Los_Angeles) -- !query select from_json('{"d":"26/October/2015"}', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), 
{"d":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"d":"26/October/2015"})#x] +Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"d":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out index 12756576ded9b..b0d128c4cab69 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out @@ -981,9 +981,13 @@ select interval '20 15:40:32.99899999' day to hour -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR` when cast to interval day to hour: 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "20 15:40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR`", + "typeName" : "interval day to hour" }, "queryContext" : [ { "objectType" : "", @@ -1000,9 +1004,13 @@ select interval '20 15:40:32.99899999' day to minute -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE` when cast to interval day to minute: 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "20 15:40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE`", + "typeName" : "interval day to minute" }, "queryContext" : [ { "objectType" : "", @@ -1019,9 +1027,13 @@ select interval '15:40:32.99899999' hour to minute -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE` when cast to interval hour to minute: 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "15:40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE`", + "typeName" : "interval hour to minute" }, "queryContext" : [ { "objectType" : "", @@ -1038,9 +1050,13 @@ select interval '15:40.99899999' hour to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 15:40.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "15:40.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`", + "typeName" : "interval hour to second" }, "queryContext" : [ { "objectType" : "", @@ -1057,9 +1073,13 @@ select interval '15:40' hour to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 15:40, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "15:40", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`", + "typeName" : "interval hour to second" }, "queryContext" : [ { "objectType" : "", @@ -1076,9 +1096,13 @@ select interval '20 40:32.99899999' minute to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND` when cast to interval minute to second: 20 40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "20 40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND`", + "typeName" : "interval minute to second" }, "queryContext" : [ { "objectType" : "", @@ -1460,9 +1484,11 @@ SELECT INTERVAL '178956970-8' YEAR TO MONTH -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.INTERVAL_PARSING", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Error parsing interval year-month string: integer overflow" + "input" : "178956970-8", + "interval" : "year-month" }, "queryContext" : [ { "objectType" : "", @@ -1909,9 +1935,13 @@ select interval '-\t2-2\t' year to month -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match year-month format of `[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH` when cast to interval year to month: -\t2-2\t" + "input" : "-\t2-2\t", + "intervalStr" : "year-month", + "supportedFormat" : "`[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH`", + "typeName" : "interval year to month" }, "queryContext" : [ { "objectType" : "", @@ -1935,9 +1965,13 @@ select interval '\n-\t10\t 12:34:46.789\t' day to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND` when cast to interval day to second: \n-\t10\t 12:34:46.789\t, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "\n-\t10\t 12:34:46.789\t", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND`", + "typeName" : "interval day to second" }, "queryContext" : [ { "objectType" : "", @@ -2074,7 +2108,7 @@ SELECT to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)), from_csv(to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)), 'a interval year, b interval month') -- !query analysis -Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x] +Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x] +- OneRowRelation @@ -2085,7 +2119,7 @@ SELECT to_json(map('a', interval 100 day 130 minute)), from_json(to_json(map('a', interval 100 day 130 minute)), 'a interval day to minute') -- !query analysis -Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), 
Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x] +Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x] +- OneRowRelation @@ -2096,7 +2130,7 @@ SELECT to_json(map('a', interval 32 year 10 month)), from_json(to_json(map('a', interval 32 year 10 month)), 'a interval year to month') -- !query analysis -Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x] +Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out index 45fc3bd03a782..ae8e47ed3665c 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out @@ -16,12 +16,12 @@ Project [from_csv(StructField(cube,IntegerType,true), 1, Some(America/Los_Angele -- !query select from_json('{"create":1}', 'create INT') -- !query analysis -Project [from_json(StructField(create,IntegerType,true), {"create":1}, Some(America/Los_Angeles)) AS from_json({"create":1})#x] +Project [from_json(StructField(create,IntegerType,true), {"create":1}, Some(America/Los_Angeles), false) AS from_json({"create":1})#x] +- OneRowRelation -- !query select from_json('{"cube":1}', 'cube INT') -- !query analysis -Project [from_json(StructField(cube,IntegerType,true), {"cube":1}, Some(America/Los_Angeles)) AS from_json({"cube":1})#x] +Project [from_json(StructField(cube,IntegerType,true), {"cube":1}, Some(America/Los_Angeles), false) AS from_json({"cube":1})#x] +- OneRowRelation 
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out index bf34490d657e3..560974d28c545 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out @@ -730,7 +730,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out index fb331089d7545..4db56d6c70561 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out @@ -194,25 +194,8 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query select sort_array(array('b', 'd'), cast(NULL as boolean)) -- !query analysis -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - "sqlState" : "42K09", - "messageParameters" : { - "inputSql" : "\"CAST(NULL AS BOOLEAN)\"", - "inputType" : "\"BOOLEAN\"", - "paramIndex" : "second", - "requiredType" : "\"BOOLEAN\"", - "sqlExpr" : "\"sort_array(array(b, d), CAST(NULL AS BOOLEAN))\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 57, - "fragment" : "sort_array(array('b', 'd'), cast(NULL as boolean))" - } ] -} +Project [sort_array(array(b, d), cast(null as boolean)) AS sort_array(array(b, d), CAST(NULL AS BOOLEAN))#x] ++- OneRowRelation -- !query diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index 14ac67eb93a32..eed7fa73ab698 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -436,6 +436,30 @@ Project [str_to_map(collate(text#x, utf8_binary), collate(pairDelim#x, utf8_bina +- Relation spark_catalog.default.t4[text#x,pairDelim#x,keyValueDelim#x] parquet +-- !query +select str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai) from t4 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(text, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"str_to_map(collate(text, unicode_ai), collate(pairDelim, unicode_ai), collate(keyValueDelim, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 106, + "fragment" : "str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, 
keyValueDelim collate unicode_ai)" + } ] +} + + -- !query drop table t4 -- !query analysis @@ -444,21 +468,227 @@ DropTable false, false -- !query -create table t5(str string collate utf8_binary, delimiter string collate utf8_lcase, partNum int) using parquet +create table t5(s string, utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase) using parquet -- !query analysis CreateDataSourceTableCommand `spark_catalog`.`default`.`t5`, false -- !query -insert into t5 values('11AB12AB13', 'AB', 2) +insert into t5 values ('Spark', 'Spark', 'SQL') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +insert into t5 values ('aaAaAAaA', 'aaAaAAaA', 'aaAaAAaA') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +insert into t5 values ('aaAaAAaA', 'aaAaAAaA', 'aaAaaAaA') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +insert into t5 values ('aaAaAAaA', 'aaAaAAaA', 'aaAaaAaAaaAaaAaAaaAaaAaA') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +insert into t5 values ('bbAbaAbA', 'bbAbAAbA', 'a') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) 
AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +insert into t5 values ('İo', 'İo', 'İo') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +insert into t5 values ('İo', 'İo', 'i̇o') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +insert into t5 values ('efd2', 'efd2', 'efd2') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +insert into t5 values ('Hello, world! Nice day.', 'Hello, world! Nice day.', 'Hello, world! Nice day.') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +insert into t5 values ('Something else. Nothing here.', 'Something else. Nothing here.', 'Something else. 
Nothing here.') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +insert into t5 values ('kitten', 'kitten', 'sitTing') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +insert into t5 values ('abc', 'abc', 'abc') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +insert into t5 values ('abcdcba', 'abcdcba', 'aBcDCbA') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +create table t6(ascii long) using parquet +-- !query analysis +CreateDataSourceTableCommand `spark_catalog`.`default`.`t6`, false + + +-- !query +insert into t6 values (97) +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t6, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t6], Append, `spark_catalog`.`default`.`t6`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t6), [ascii] ++- Project [cast(col1#x as bigint) AS ascii#xL] + +- LocalRelation [col1#x] + + +-- !query +insert into t6 values (66) +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t6, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t6], Append, `spark_catalog`.`default`.`t6`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t6), [ascii] ++- Project [cast(col1#x as bigint) AS ascii#xL] + +- LocalRelation [col1#x] + + +-- !query 
+create table t7(ascii double) using parquet +-- !query analysis +CreateDataSourceTableCommand `spark_catalog`.`default`.`t7`, false + + +-- !query +insert into t7 values (97.52143) +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t7, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t7], Append, `spark_catalog`.`default`.`t7`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t7), [ascii] ++- Project [cast(col1#x as double) AS ascii#x] + +- LocalRelation [col1#x] + + +-- !query +insert into t7 values (66.421) -- !query analysis -InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [str, delimiter, partNum] -+- Project [cast(col1#x as string) AS str#x, cast(col2#x as string collate UTF8_LCASE) AS delimiter#x, cast(col3#x as int) AS partNum#x] +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t7, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t7], Append, `spark_catalog`.`default`.`t7`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t7), [ascii] ++- Project [cast(col1#x as double) AS ascii#x] + +- LocalRelation [col1#x] + + +-- !query +create table t8(format string collate utf8_binary, utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase) using parquet +-- !query analysis +CreateDataSourceTableCommand `spark_catalog`.`default`.`t8`, false + + +-- !query +insert into t8 values ('%s%s', 'abCdE', 'abCdE') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t8, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t8], Append, `spark_catalog`.`default`.`t8`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t8), [format, utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS format#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x] +- LocalRelation [col1#x, col2#x, col3#x] -- !query -select split_part(str, delimiter, partNum) from t5 +create table t9(num long) using parquet +-- !query analysis +CreateDataSourceTableCommand `spark_catalog`.`default`.`t9`, false + + +-- !query +insert into t9 values (97) +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t9, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t9], Append, `spark_catalog`.`default`.`t9`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t9), [num] ++- Project [cast(col1#x as bigint) AS num#xL] + +- LocalRelation [col1#x] + + +-- !query +insert into t9 values (66) +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t9, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t9], Append, `spark_catalog`.`default`.`t9`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t9), [num] ++- Project [cast(col1#x as bigint) AS num#xL] + +- 
LocalRelation [col1#x] + + +-- !query +create table t10(utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase) using parquet +-- !query analysis +CreateDataSourceTableCommand `spark_catalog`.`default`.`t10`, false + + +-- !query +insert into t10 values ('aaAaAAaA', 'aaAaaAaA') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t10, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t10], Append, `spark_catalog`.`default`.`t10`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t10), [utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS utf8_binary#x, cast(col2#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +insert into t10 values ('efd2', 'efd2') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t10, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t10], Append, `spark_catalog`.`default`.`t10`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t10), [utf8_binary, utf8_lcase] ++- Project [cast(col1#x as string) AS utf8_binary#x, cast(col2#x as string collate UTF8_LCASE) AS utf8_lcase#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select concat_ws(' ', utf8_lcase, utf8_lcase) from t5 +-- !query analysis +Project [concat_ws(cast( as string collate UTF8_LCASE), utf8_lcase#x, utf8_lcase#x) AS concat_ws( , utf8_lcase, utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select concat_ws(' ', utf8_binary, utf8_lcase) from t5 -- !query analysis org.apache.spark.sql.AnalysisException { @@ -468,7 +698,7 @@ org.apache.spark.sql.AnalysisException -- !query -select split_part(str collate utf8_binary, delimiter collate utf8_lcase, partNum) from t5 +select concat_ws(' ' collate utf8_binary, utf8_binary, 'SQL' collate utf8_lcase) from t5 -- !query analysis org.apache.spark.sql.AnalysisException { @@ -481,36 +711,39 @@ org.apache.spark.sql.AnalysisException -- !query -select split_part(str collate utf8_binary, delimiter collate utf8_binary, partNum) from t5 +select concat_ws(' ' collate utf8_lcase, utf8_binary, 'SQL' collate utf8_lcase) from t5 -- !query analysis -Project [split_part(collate(str#x, utf8_binary), collate(delimiter#x, utf8_binary), partNum#x) AS split_part(collate(str, utf8_binary), collate(delimiter, utf8_binary), partNum)#x] +Project [concat_ws(collate( , utf8_lcase), cast(utf8_binary#x as string collate UTF8_LCASE), collate(SQL, utf8_lcase)) AS concat_ws(collate( , utf8_lcase), utf8_binary, collate(SQL, utf8_lcase))#x] +- SubqueryAlias spark_catalog.default.t5 - +- Relation spark_catalog.default.t5[str#x,delimiter#x,partNum#x] parquet + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet -- !query -drop table t5 +select concat_ws(',', utf8_lcase, 'word'), concat_ws(',', utf8_binary, 'word') from t5 -- !query analysis -DropTable false, false -+- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t5 +Project [concat_ws(cast(, as string collate UTF8_LCASE), utf8_lcase#x, cast(word as string collate UTF8_LCASE)) AS concat_ws(,, utf8_lcase, word)#x, concat_ws(,, utf8_binary#x, word) AS concat_ws(,, utf8_binary, word)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation 
spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet -- !query -create table t6 (utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase, threshold int) using parquet +select concat_ws(',', utf8_lcase, 'word' collate utf8_binary), concat_ws(',', utf8_binary, 'word' collate utf8_lcase) from t5 -- !query analysis -CreateDataSourceTableCommand `spark_catalog`.`default`.`t6`, false +Project [concat_ws(,, cast(utf8_lcase#x as string), collate(word, utf8_binary)) AS concat_ws(,, utf8_lcase, collate(word, utf8_binary))#x, concat_ws(cast(, as string collate UTF8_LCASE), cast(utf8_binary#x as string collate UTF8_LCASE), collate(word, utf8_lcase)) AS concat_ws(,, utf8_binary, collate(word, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet -- !query -insert into t6 values('kitten', 'sitting', 2) +select elt(2, s, utf8_binary) from t5 -- !query analysis -InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t6, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t6], Append, `spark_catalog`.`default`.`t6`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t6), [utf8_binary, utf8_lcase, threshold] -+- Project [cast(col1#x as string) AS utf8_binary#x, cast(col2#x as string collate UTF8_LCASE) AS utf8_lcase#x, cast(col3#x as int) AS threshold#x] - +- LocalRelation [col1#x, col2#x, col3#x] +Project [elt(2, s#x, utf8_binary#x, false) AS elt(2, s, utf8_binary)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet -- !query -select levenshtein(utf8_binary, utf8_lcase) from t6 +select elt(2, utf8_binary, utf8_lcase, s) from t5 -- !query analysis org.apache.spark.sql.AnalysisException { @@ -520,7 +753,7 @@ org.apache.spark.sql.AnalysisException -- !query -select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_lcase) from t6 +select elt(1, utf8_binary collate utf8_binary, utf8_lcase collate utf8_lcase) from t5 -- !query analysis org.apache.spark.sql.AnalysisException { @@ -533,15 +766,39 @@ org.apache.spark.sql.AnalysisException -- !query -select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_binary) from t6 +select elt(1, utf8_binary collate utf8_binary, utf8_lcase collate utf8_binary) from t5 -- !query analysis -Project [levenshtein(collate(utf8_binary#x, utf8_binary), collate(utf8_lcase#x, utf8_binary), None) AS levenshtein(collate(utf8_binary, utf8_binary), collate(utf8_lcase, utf8_binary))#x] -+- SubqueryAlias spark_catalog.default.t6 - +- Relation spark_catalog.default.t6[utf8_binary#x,utf8_lcase#x,threshold#x] parquet +Project [elt(1, collate(utf8_binary#x, utf8_binary), collate(utf8_lcase#x, utf8_binary), false) AS elt(1, collate(utf8_binary, utf8_binary), collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select elt(1, utf8_binary collate utf8_binary, utf8_lcase) from t5 +-- !query analysis +Project [elt(1, collate(utf8_binary#x, utf8_binary), cast(utf8_lcase#x as string), false) AS elt(1, collate(utf8_binary, utf8_binary), utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select elt(1, utf8_binary, 'word'), elt(1, utf8_lcase, 'word') from 
t5 +-- !query analysis +Project [elt(1, utf8_binary#x, word, false) AS elt(1, utf8_binary, word)#x, elt(1, utf8_lcase#x, cast(word as string collate UTF8_LCASE), false) AS elt(1, utf8_lcase, word)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet -- !query -select levenshtein(utf8_binary, utf8_lcase, threshold) from t6 +select elt(1, utf8_binary, 'word' collate utf8_lcase), elt(1, utf8_lcase, 'word' collate utf8_binary) from t5 +-- !query analysis +Project [elt(1, cast(utf8_binary#x as string collate UTF8_LCASE), collate(word, utf8_lcase), false) AS elt(1, utf8_binary, collate(word, utf8_lcase))#x, elt(1, cast(utf8_lcase#x as string), collate(word, utf8_binary), false) AS elt(1, utf8_lcase, collate(word, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select split_part(utf8_binary, utf8_lcase, 3) from t5 -- !query analysis org.apache.spark.sql.AnalysisException { @@ -551,7 +808,15 @@ org.apache.spark.sql.AnalysisException -- !query -select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_lcase, threshold) from t6 +select split_part(s, utf8_binary, 1) from t5 +-- !query analysis +Project [split_part(s#x, utf8_binary#x, 1) AS split_part(s, utf8_binary, 1)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select split_part(utf8_binary collate utf8_binary, s collate utf8_lcase, 1) from t5 -- !query analysis org.apache.spark.sql.AnalysisException { @@ -564,15 +829,1857 @@ org.apache.spark.sql.AnalysisException -- !query -select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_binary, threshold) from t6 +select split_part(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5 -- !query analysis -Project [levenshtein(collate(utf8_binary#x, utf8_binary), collate(utf8_lcase#x, utf8_binary), Some(threshold#x)) AS levenshtein(collate(utf8_binary, utf8_binary), collate(utf8_lcase, utf8_binary), threshold)#x] -+- SubqueryAlias spark_catalog.default.t6 - +- Relation spark_catalog.default.t6[utf8_binary#x,utf8_lcase#x,threshold#x] parquet +Project [split_part(utf8_binary#x, collate(utf8_lcase#x, utf8_binary), 2) AS split_part(utf8_binary, collate(utf8_lcase, utf8_binary), 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet -- !query -drop table t6 +select split_part(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5 -- !query analysis -DropTable false, false -+- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t6 +Project [split_part(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), 2) AS split_part(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase), 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : 
"\"split_part(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + +-- !query +select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5 +-- !query analysis +Project [split_part(utf8_binary#x, a, 3) AS split_part(utf8_binary, a, 3)#x, split_part(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 3) AS split_part(utf8_lcase, a, 3)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select split_part(utf8_binary, 'a' collate utf8_lcase, 3), split_part(utf8_lcase, 'a' collate utf8_binary, 3) from t5 +-- !query analysis +Project [split_part(cast(utf8_binary#x as string collate UTF8_LCASE), collate(a, utf8_lcase), 3) AS split_part(utf8_binary, collate(a, utf8_lcase), 3)#x, split_part(cast(utf8_lcase#x as string), collate(a, utf8_binary), 3) AS split_part(utf8_lcase, collate(a, utf8_binary), 3)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select contains(utf8_binary, utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select contains(s, utf8_binary) from t5 +-- !query analysis +Project [Contains(s#x, utf8_binary#x) AS contains(s, utf8_binary)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select contains(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select contains(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [Contains(utf8_binary#x, collate(utf8_lcase#x, utf8_binary)) AS contains(utf8_binary, collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select contains(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query analysis +Project [Contains(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase)) AS contains(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"contains(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "contains(utf8_binary collate 
unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query +select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5 +-- !query analysis +Project [Contains(utf8_binary#x, a) AS contains(utf8_binary, a)#x, Contains(utf8_lcase#x, cast(a as string collate UTF8_LCASE)) AS contains(utf8_lcase, a)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select contains(utf8_binary, 'AaAA' collate utf8_lcase), contains(utf8_lcase, 'AAa' collate utf8_binary) from t5 +-- !query analysis +Project [Contains(cast(utf8_binary#x as string collate UTF8_LCASE), collate(AaAA, utf8_lcase)) AS contains(utf8_binary, collate(AaAA, utf8_lcase))#x, Contains(cast(utf8_lcase#x as string), collate(AAa, utf8_binary)) AS contains(utf8_lcase, collate(AAa, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select substring_index(utf8_binary, utf8_lcase, 2) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select substring_index(s, utf8_binary,1) from t5 +-- !query analysis +Project [substring_index(s#x, utf8_binary#x, 1) AS substring_index(s, utf8_binary, 1)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select substring_index(utf8_binary collate utf8_binary, s collate utf8_lcase, 3) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select substring_index(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5 +-- !query analysis +Project [substring_index(utf8_binary#x, collate(utf8_lcase#x, utf8_binary), 2) AS substring_index(utf8_binary, collate(utf8_lcase, utf8_binary), 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select substring_index(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5 +-- !query analysis +Project [substring_index(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), 2) AS substring_index(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase), 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"substring_index(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 88, + "fragment" : "substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + +-- !query +select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5 
+-- !query analysis +Project [substring_index(utf8_binary#x, a, 2) AS substring_index(utf8_binary, a, 2)#x, substring_index(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 2) AS substring_index(utf8_lcase, a, 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select substring_index(utf8_binary, 'AaAA' collate utf8_lcase, 2), substring_index(utf8_lcase, 'AAa' collate utf8_binary, 2) from t5 +-- !query analysis +Project [substring_index(cast(utf8_binary#x as string collate UTF8_LCASE), collate(AaAA, utf8_lcase), 2) AS substring_index(utf8_binary, collate(AaAA, utf8_lcase), 2)#x, substring_index(cast(utf8_lcase#x as string), collate(AAa, utf8_binary), 2) AS substring_index(utf8_lcase, collate(AAa, utf8_binary), 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select instr(utf8_binary, utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select instr(s, utf8_binary) from t5 +-- !query analysis +Project [instr(s#x, utf8_binary#x) AS instr(s, utf8_binary)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select instr(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select instr(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [instr(utf8_binary#x, collate(utf8_lcase#x, utf8_binary)) AS instr(utf8_binary, collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select instr(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query analysis +Project [instr(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase)) AS instr(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"instr(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query +select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5 +-- !query analysis +Project [instr(utf8_binary#x, a) AS instr(utf8_binary, a)#x, instr(utf8_lcase#x, cast(a as string collate UTF8_LCASE)) AS instr(utf8_lcase, a)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation 
spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select instr(utf8_binary, 'AaAA' collate utf8_lcase), instr(utf8_lcase, 'AAa' collate utf8_binary) from t5 +-- !query analysis +Project [instr(cast(utf8_binary#x as string collate UTF8_LCASE), collate(AaAA, utf8_lcase)) AS instr(utf8_binary, collate(AaAA, utf8_lcase))#x, instr(cast(utf8_lcase#x as string), collate(AAa, utf8_binary)) AS instr(utf8_lcase, collate(AAa, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select find_in_set(utf8_binary, utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select find_in_set(s, utf8_binary) from t5 +-- !query analysis +Project [find_in_set(s#x, utf8_binary#x) AS find_in_set(s, utf8_binary)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select find_in_set(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select find_in_set(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [find_in_set(utf8_binary#x, collate(utf8_lcase#x, utf8_binary)) AS find_in_set(utf8_binary, collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select find_in_set(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query analysis +Project [find_in_set(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase)) AS find_in_set(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select find_in_set(utf8_binary, 'aaAaaAaA,i̇o'), find_in_set(utf8_lcase, 'aaAaaAaA,i̇o') from t5 +-- !query analysis +Project [find_in_set(utf8_binary#x, aaAaaAaA,i̇o) AS find_in_set(utf8_binary, aaAaaAaA,i̇o)#x, find_in_set(utf8_lcase#x, cast(aaAaaAaA,i̇o as string collate UTF8_LCASE)) AS find_in_set(utf8_lcase, aaAaaAaA,i̇o)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select find_in_set(utf8_binary, 'aaAaaAaA,i̇o' collate utf8_lcase), find_in_set(utf8_lcase, 'aaAaaAaA,i̇o' collate utf8_binary) from t5 +-- !query analysis +Project [find_in_set(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaaAaA,i̇o, utf8_lcase)) AS find_in_set(utf8_binary, collate(aaAaaAaA,i̇o, utf8_lcase))#x, find_in_set(cast(utf8_lcase#x as string), collate(aaAaaAaA,i̇o, utf8_binary)) AS find_in_set(utf8_lcase, collate(aaAaaAaA,i̇o, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select startswith(utf8_binary, utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select startswith(s, utf8_binary) from t5 +-- !query analysis 
+Project [StartsWith(s#x, utf8_binary#x) AS startswith(s, utf8_binary)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select startswith(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select startswith(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [StartsWith(utf8_binary#x, collate(utf8_lcase#x, utf8_binary)) AS startswith(utf8_binary, collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select startswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query analysis +Project [StartsWith(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase)) AS startswith(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"startswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 80, + "fragment" : "startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query +select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5 +-- !query analysis +Project [StartsWith(utf8_binary#x, aaAaaAaA) AS startswith(utf8_binary, aaAaaAaA)#x, StartsWith(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE)) AS startswith(utf8_lcase, aaAaaAaA)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select startswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), startswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5 +-- !query analysis +Project [StartsWith(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaaAaA, utf8_lcase)) AS startswith(utf8_binary, collate(aaAaaAaA, utf8_lcase))#x, StartsWith(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary)) AS startswith(utf8_lcase, collate(aaAaaAaA, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select translate(utf8_lcase, utf8_lcase, '12345') from t5 +-- !query analysis +Project [translate(utf8_lcase#x, utf8_lcase#x, cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_lcase, utf8_lcase, 12345)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select translate(utf8_binary, utf8_lcase, '12345') from t5 +-- !query analysis 
+org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select translate(utf8_binary, 'aBc' collate utf8_lcase, '12345' collate utf8_binary) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string collate UTF8_LCASE`, `string`" + } +} + + +-- !query +select translate(utf8_binary, 'SQL' collate utf8_lcase, '12345' collate utf8_lcase) from t5 +-- !query analysis +Project [translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(SQL, utf8_lcase), collate(12345, utf8_lcase)) AS translate(utf8_binary, collate(SQL, utf8_lcase), collate(12345, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"utf8_binary\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"translate(utf8_binary, collate(SQL, unicode_ai), collate(12345, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai)" + } ] +} + + +-- !query +select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5 +-- !query analysis +Project [translate(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE), cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_lcase, aaAaaAaA, 12345)#x, translate(utf8_binary#x, aaAaaAaA, 12345) AS translate(utf8_binary, aaAaaAaA, 12345)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select translate(utf8_lcase, 'aBc' collate utf8_binary, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5 +-- !query analysis +Project [translate(cast(utf8_lcase#x as string), collate(aBc, utf8_binary), 12345) AS translate(utf8_lcase, collate(aBc, utf8_binary), 12345)#x, translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aBc, utf8_lcase), cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_binary, collate(aBc, utf8_lcase), 12345)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select replace(utf8_binary, utf8_lcase, 'abc') from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select replace(s, utf8_binary, 'abc') from t5 +-- !query analysis +Project [replace(s#x, utf8_binary#x, abc) AS replace(s, utf8_binary, abc)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select replace(utf8_binary collate utf8_binary, s collate utf8_lcase, 'abc') from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { 
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select replace(utf8_binary, utf8_lcase collate utf8_binary, 'abc') from t5 +-- !query analysis +Project [replace(utf8_binary#x, collate(utf8_lcase#x, utf8_binary), abc) AS replace(utf8_binary, collate(utf8_lcase, utf8_binary), abc)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select replace(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 'abc') from t5 +-- !query analysis +Project [replace(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), cast(abc as string collate UTF8_LCASE)) AS replace(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase), abc)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc') from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), abc)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 84, + "fragment" : "replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc')" + } ] +} + + +-- !query +select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5 +-- !query analysis +Project [replace(utf8_binary#x, aaAaaAaA, abc) AS replace(utf8_binary, aaAaaAaA, abc)#x, replace(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE), cast(abc as string collate UTF8_LCASE)) AS replace(utf8_lcase, aaAaaAaA, abc)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select replace(utf8_binary, 'aaAaaAaA' collate utf8_lcase, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5 +-- !query analysis +Project [replace(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaaAaA, utf8_lcase), cast(abc as string collate UTF8_LCASE)) AS replace(utf8_binary, collate(aaAaaAaA, utf8_lcase), abc)#x, replace(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary), abc) AS replace(utf8_lcase, collate(aaAaaAaA, utf8_binary), abc)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select endswith(utf8_binary, utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select endswith(s, utf8_binary) from t5 +-- !query analysis +Project [EndsWith(s#x, utf8_binary#x) AS endswith(s, utf8_binary)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select endswith(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : 
"`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select endswith(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [EndsWith(utf8_binary#x, collate(utf8_lcase#x, utf8_binary)) AS endswith(utf8_binary, collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select endswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query analysis +Project [EndsWith(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase)) AS endswith(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"endswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query +select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5 +-- !query analysis +Project [EndsWith(utf8_binary#x, aaAaaAaA) AS endswith(utf8_binary, aaAaaAaA)#x, EndsWith(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE)) AS endswith(utf8_lcase, aaAaaAaA)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select endswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), endswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5 +-- !query analysis +Project [EndsWith(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaaAaA, utf8_lcase)) AS endswith(utf8_binary, collate(aaAaaAaA, utf8_lcase))#x, EndsWith(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary)) AS endswith(utf8_lcase, collate(aaAaaAaA, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select repeat(utf8_binary, 3), repeat(utf8_lcase, 2) from t5 +-- !query analysis +Project [repeat(utf8_binary#x, 3) AS repeat(utf8_binary, 3)#x, repeat(utf8_lcase#x, 2) AS repeat(utf8_lcase, 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select repeat(utf8_binary collate utf8_lcase, 3), repeat(utf8_lcase collate utf8_binary, 2) from t5 +-- !query analysis +Project [repeat(collate(utf8_binary#x, utf8_lcase), 3) AS repeat(collate(utf8_binary, utf8_lcase), 3)#x, repeat(collate(utf8_lcase#x, utf8_binary), 2) AS repeat(collate(utf8_lcase, utf8_binary), 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select ascii(utf8_binary), ascii(utf8_lcase) from t5 +-- !query analysis +Project [ascii(utf8_binary#x) AS ascii(utf8_binary)#x, ascii(utf8_lcase#x) AS 
ascii(utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select ascii(utf8_binary collate utf8_lcase), ascii(utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [ascii(collate(utf8_binary#x, utf8_lcase)) AS ascii(collate(utf8_binary, utf8_lcase))#x, ascii(collate(utf8_lcase#x, utf8_binary)) AS ascii(collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select unbase64(utf8_binary), unbase64(utf8_lcase) from t10 +-- !query analysis +Project [unbase64(utf8_binary#x, false) AS unbase64(utf8_binary)#x, unbase64(utf8_lcase#x, false) AS unbase64(utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t10 + +- Relation spark_catalog.default.t10[utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select unbase64(utf8_binary collate utf8_lcase), unbase64(utf8_lcase collate utf8_binary) from t10 +-- !query analysis +Project [unbase64(collate(utf8_binary#x, utf8_lcase), false) AS unbase64(collate(utf8_binary, utf8_lcase))#x, unbase64(collate(utf8_lcase#x, utf8_binary), false) AS unbase64(collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t10 + +- Relation spark_catalog.default.t10[utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select chr(ascii) from t6 +-- !query analysis +Project [chr(ascii#xL) AS chr(ascii)#x] ++- SubqueryAlias spark_catalog.default.t6 + +- Relation spark_catalog.default.t6[ascii#xL] parquet + + +-- !query +select base64(utf8_binary), base64(utf8_lcase) from t5 +-- !query analysis +Project [base64(cast(utf8_binary#x as binary)) AS base64(utf8_binary)#x, base64(cast(utf8_lcase#x as binary)) AS base64(utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select base64(utf8_binary collate utf8_lcase), base64(utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [base64(cast(collate(utf8_binary#x, utf8_lcase) as binary)) AS base64(collate(utf8_binary, utf8_lcase))#x, base64(cast(collate(utf8_lcase#x, utf8_binary) as binary)) AS base64(collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select decode(encode(utf8_binary, 'utf-8'), 'utf-8'), decode(encode(utf8_lcase, 'utf-8'), 'utf-8') from t5 +-- !query analysis +Project [decode(encode(utf8_binary#x, utf-8), utf-8) AS decode(encode(utf8_binary, utf-8), utf-8)#x, decode(encode(utf8_lcase#x, utf-8), utf-8) AS decode(encode(utf8_lcase, utf-8), utf-8)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select decode(encode(utf8_binary collate utf8_lcase, 'utf-8'), 'utf-8'), decode(encode(utf8_lcase collate utf8_binary, 'utf-8'), 'utf-8') from t5 +-- !query analysis +Project [decode(encode(collate(utf8_binary#x, utf8_lcase), utf-8), utf-8) AS decode(encode(collate(utf8_binary, utf8_lcase), utf-8), utf-8)#x, decode(encode(collate(utf8_lcase#x, utf8_binary), utf-8), utf-8) AS decode(encode(collate(utf8_lcase, utf8_binary), utf-8), utf-8)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select format_number(ascii, '###.###') from t7 +-- !query analysis +Project 
[format_number(ascii#x, ###.###) AS format_number(ascii, ###.###)#x] ++- SubqueryAlias spark_catalog.default.t7 + +- Relation spark_catalog.default.t7[ascii#x] parquet + + +-- !query +select format_number(ascii, '###.###' collate utf8_lcase) from t7 +-- !query analysis +Project [format_number(ascii#x, collate(###.###, utf8_lcase)) AS format_number(ascii, collate(###.###, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t7 + +- Relation spark_catalog.default.t7[ascii#x] parquet + + +-- !query +select encode(utf8_binary, 'utf-8'), encode(utf8_lcase, 'utf-8') from t5 +-- !query analysis +Project [encode(utf8_binary#x, utf-8) AS encode(utf8_binary, utf-8)#x, encode(utf8_lcase#x, utf-8) AS encode(utf8_lcase, utf-8)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select encode(utf8_binary collate utf8_lcase, 'utf-8'), encode(utf8_lcase collate utf8_binary, 'utf-8') from t5 +-- !query analysis +Project [encode(collate(utf8_binary#x, utf8_lcase), utf-8) AS encode(collate(utf8_binary, utf8_lcase), utf-8)#x, encode(collate(utf8_lcase#x, utf8_binary), utf-8) AS encode(collate(utf8_lcase, utf8_binary), utf-8)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select to_binary(utf8_binary, 'utf-8'), to_binary(utf8_lcase, 'utf-8') from t5 +-- !query analysis +Project [to_binary(utf8_binary#x, Some(utf-8), false) AS to_binary(utf8_binary, utf-8)#x, to_binary(utf8_lcase#x, Some(utf-8), false) AS to_binary(utf8_lcase, utf-8)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select to_binary(utf8_binary collate utf8_lcase, 'utf-8'), to_binary(utf8_lcase collate utf8_binary, 'utf-8') from t5 +-- !query analysis +Project [to_binary(collate(utf8_binary#x, utf8_lcase), Some(utf-8), false) AS to_binary(collate(utf8_binary, utf8_lcase), utf-8)#x, to_binary(collate(utf8_lcase#x, utf8_binary), Some(utf-8), false) AS to_binary(collate(utf8_lcase, utf8_binary), utf-8)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select sentences(utf8_binary), sentences(utf8_lcase) from t5 +-- !query analysis +Project [sentences(utf8_binary#x, , ) AS sentences(utf8_binary, , )#x, sentences(utf8_lcase#x, , ) AS sentences(utf8_lcase, , )#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select sentences(utf8_binary collate utf8_lcase), sentences(utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [sentences(collate(utf8_binary#x, utf8_lcase), , ) AS sentences(collate(utf8_binary, utf8_lcase), , )#x, sentences(collate(utf8_lcase#x, utf8_binary), , ) AS sentences(collate(utf8_lcase, utf8_binary), , )#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select upper(utf8_binary), upper(utf8_lcase) from t5 +-- !query analysis +Project [upper(utf8_binary#x) AS upper(utf8_binary)#x, upper(utf8_lcase#x) AS upper(utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select upper(utf8_binary collate utf8_lcase), upper(utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project 
[upper(collate(utf8_binary#x, utf8_lcase)) AS upper(collate(utf8_binary, utf8_lcase))#x, upper(collate(utf8_lcase#x, utf8_binary)) AS upper(collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select lower(utf8_binary), lower(utf8_lcase) from t5 +-- !query analysis +Project [lower(utf8_binary#x) AS lower(utf8_binary)#x, lower(utf8_lcase#x) AS lower(utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select lower(utf8_binary collate utf8_lcase), lower(utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [lower(collate(utf8_binary#x, utf8_lcase)) AS lower(collate(utf8_binary, utf8_lcase))#x, lower(collate(utf8_lcase#x, utf8_binary)) AS lower(collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select initcap(utf8_binary), initcap(utf8_lcase) from t5 +-- !query analysis +Project [initcap(utf8_binary#x) AS initcap(utf8_binary)#x, initcap(utf8_lcase#x) AS initcap(utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select initcap(utf8_binary collate utf8_lcase), initcap(utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [initcap(collate(utf8_binary#x, utf8_lcase)) AS initcap(collate(utf8_binary, utf8_lcase))#x, initcap(collate(utf8_lcase#x, utf8_binary)) AS initcap(collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select overlay(utf8_binary, utf8_lcase, 2) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select overlay(s, utf8_binary,1) from t5 +-- !query analysis +Project [overlay(s#x, utf8_binary#x, 1, -1) AS overlay(s, utf8_binary, 1, -1)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select overlay(utf8_binary collate utf8_binary, s collate utf8_lcase, 3) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select overlay(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5 +-- !query analysis +Project [overlay(utf8_binary#x, collate(utf8_lcase#x, utf8_binary), 2, -1) AS overlay(utf8_binary, collate(utf8_lcase, utf8_binary), 2, -1)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select overlay(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5 +-- !query analysis +Project [overlay(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), 2, -1) AS overlay(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase), 2, -1)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select overlay(utf8_binary, 'a', 2), overlay(utf8_lcase, 'a', 2) from t5 +-- !query analysis +Project 
[overlay(utf8_binary#x, a, 2, -1) AS overlay(utf8_binary, a, 2, -1)#x, overlay(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 2, -1) AS overlay(utf8_lcase, a, 2, -1)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select overlay(utf8_binary, 'AaAA' collate utf8_lcase, 2), overlay(utf8_lcase, 'AAa' collate utf8_binary, 2) from t5 +-- !query analysis +Project [overlay(cast(utf8_binary#x as string collate UTF8_LCASE), collate(AaAA, utf8_lcase), 2, -1) AS overlay(utf8_binary, collate(AaAA, utf8_lcase), 2, -1)#x, overlay(cast(utf8_lcase#x as string), collate(AAa, utf8_binary), 2, -1) AS overlay(utf8_lcase, collate(AAa, utf8_binary), 2, -1)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select format_string(format, utf8_binary, utf8_lcase) from t8 +-- !query analysis +Project [format_string(format#x, utf8_binary#x, utf8_lcase#x) AS format_string(format, utf8_binary, utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t8 + +- Relation spark_catalog.default.t8[format#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select format_string(format collate utf8_lcase, utf8_lcase, utf8_binary collate utf8_lcase, 3), format_string(format, utf8_lcase collate utf8_binary, utf8_binary) from t8 +-- !query analysis +Project [format_string(collate(format#x, utf8_lcase), utf8_lcase#x, collate(utf8_binary#x, utf8_lcase), 3) AS format_string(collate(format, utf8_lcase), utf8_lcase, collate(utf8_binary, utf8_lcase), 3)#x, format_string(format#x, collate(utf8_lcase#x, utf8_binary), utf8_binary#x) AS format_string(format, collate(utf8_lcase, utf8_binary), utf8_binary)#x] ++- SubqueryAlias spark_catalog.default.t8 + +- Relation spark_catalog.default.t8[format#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select format_string(format, utf8_binary, utf8_lcase) from t8 +-- !query analysis +Project [format_string(format#x, utf8_binary#x, utf8_lcase#x) AS format_string(format, utf8_binary, utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t8 + +- Relation spark_catalog.default.t8[format#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select soundex(utf8_binary), soundex(utf8_lcase) from t5 +-- !query analysis +Project [soundex(utf8_binary#x) AS soundex(utf8_binary)#x, soundex(utf8_lcase#x) AS soundex(utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select soundex(utf8_binary collate utf8_lcase), soundex(utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [soundex(collate(utf8_binary#x, utf8_lcase)) AS soundex(collate(utf8_binary, utf8_lcase))#x, soundex(collate(utf8_lcase#x, utf8_binary)) AS soundex(collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select length(utf8_binary), length(utf8_lcase) from t5 +-- !query analysis +Project [length(utf8_binary#x) AS length(utf8_binary)#x, length(utf8_lcase#x) AS length(utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select length(utf8_binary collate utf8_lcase), length(utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [length(collate(utf8_binary#x, utf8_lcase)) AS length(collate(utf8_binary, 
utf8_lcase))#x, length(collate(utf8_lcase#x, utf8_binary)) AS length(collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select bit_length(utf8_binary), bit_length(utf8_lcase) from t5 +-- !query analysis +Project [bit_length(utf8_binary#x) AS bit_length(utf8_binary)#x, bit_length(utf8_lcase#x) AS bit_length(utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select bit_length(utf8_binary collate utf8_lcase), bit_length(utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [bit_length(collate(utf8_binary#x, utf8_lcase)) AS bit_length(collate(utf8_binary, utf8_lcase))#x, bit_length(collate(utf8_lcase#x, utf8_binary)) AS bit_length(collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select octet_length(utf8_binary), octet_length(utf8_lcase) from t5 +-- !query analysis +Project [octet_length(utf8_binary#x) AS octet_length(utf8_binary)#x, octet_length(utf8_lcase#x) AS octet_length(utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select octet_length(utf8_binary collate utf8_lcase), octet_length(utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [octet_length(collate(utf8_binary#x, utf8_lcase)) AS octet_length(collate(utf8_binary, utf8_lcase))#x, octet_length(collate(utf8_lcase#x, utf8_binary)) AS octet_length(collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select luhn_check(num) from t9 +-- !query analysis +Project [luhn_check(cast(num#xL as string)) AS luhn_check(num)#x] ++- SubqueryAlias spark_catalog.default.t9 + +- Relation spark_catalog.default.t9[num#xL] parquet + + +-- !query +select levenshtein(utf8_binary, utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select levenshtein(s, utf8_binary) from t5 +-- !query analysis +Project [levenshtein(s#x, utf8_binary#x, None) AS levenshtein(s, utf8_binary)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select levenshtein(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select levenshtein(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [levenshtein(utf8_binary#x, collate(utf8_lcase#x, utf8_binary), None) AS levenshtein(utf8_binary, collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select levenshtein(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query analysis +Project [levenshtein(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), None) AS levenshtein(collate(utf8_binary, 
utf8_lcase), collate(utf8_lcase, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select levenshtein(utf8_binary, 'a'), levenshtein(utf8_lcase, 'a') from t5 +-- !query analysis +Project [levenshtein(utf8_binary#x, a, None) AS levenshtein(utf8_binary, a)#x, levenshtein(utf8_lcase#x, cast(a as string collate UTF8_LCASE), None) AS levenshtein(utf8_lcase, a)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select levenshtein(utf8_binary, 'AaAA' collate utf8_lcase, 3), levenshtein(utf8_lcase, 'AAa' collate utf8_binary, 4) from t5 +-- !query analysis +Project [levenshtein(cast(utf8_binary#x as string collate UTF8_LCASE), collate(AaAA, utf8_lcase), Some(3)) AS levenshtein(utf8_binary, collate(AaAA, utf8_lcase), 3)#x, levenshtein(cast(utf8_lcase#x as string), collate(AAa, utf8_binary), Some(4)) AS levenshtein(utf8_lcase, collate(AAa, utf8_binary), 4)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select is_valid_utf8(utf8_binary), is_valid_utf8(utf8_lcase) from t5 +-- !query analysis +Project [is_valid_utf8(utf8_binary#x) AS is_valid_utf8(utf8_binary)#x, is_valid_utf8(utf8_lcase#x) AS is_valid_utf8(utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select is_valid_utf8(utf8_binary collate utf8_lcase), is_valid_utf8(utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [is_valid_utf8(collate(utf8_binary#x, utf8_lcase)) AS is_valid_utf8(collate(utf8_binary, utf8_lcase))#x, is_valid_utf8(collate(utf8_lcase#x, utf8_binary)) AS is_valid_utf8(collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select make_valid_utf8(utf8_binary), make_valid_utf8(utf8_lcase) from t5 +-- !query analysis +Project [make_valid_utf8(utf8_binary#x) AS make_valid_utf8(utf8_binary)#x, make_valid_utf8(utf8_lcase#x) AS make_valid_utf8(utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select make_valid_utf8(utf8_binary collate utf8_lcase), make_valid_utf8(utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [make_valid_utf8(collate(utf8_binary#x, utf8_lcase)) AS make_valid_utf8(collate(utf8_binary, utf8_lcase))#x, make_valid_utf8(collate(utf8_lcase#x, utf8_binary)) AS make_valid_utf8(collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select validate_utf8(utf8_binary), validate_utf8(utf8_lcase) from t5 +-- !query analysis +Project [validate_utf8(utf8_binary#x) AS validate_utf8(utf8_binary)#x, validate_utf8(utf8_lcase#x) AS validate_utf8(utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select validate_utf8(utf8_binary collate utf8_lcase), validate_utf8(utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [validate_utf8(collate(utf8_binary#x, utf8_lcase)) AS validate_utf8(collate(utf8_binary, utf8_lcase))#x, validate_utf8(collate(utf8_lcase#x, utf8_binary)) 
AS validate_utf8(collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select try_validate_utf8(utf8_binary), try_validate_utf8(utf8_lcase) from t5 +-- !query analysis +Project [try_validate_utf8(utf8_binary#x) AS try_validate_utf8(utf8_binary)#x, try_validate_utf8(utf8_lcase#x) AS try_validate_utf8(utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select try_validate_utf8(utf8_binary collate utf8_lcase), try_validate_utf8(utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [try_validate_utf8(collate(utf8_binary#x, utf8_lcase)) AS try_validate_utf8(collate(utf8_binary, utf8_lcase))#x, try_validate_utf8(collate(utf8_lcase#x, utf8_binary)) AS try_validate_utf8(collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select substr(utf8_binary, 2, 2), substr(utf8_lcase, 2, 2) from t5 +-- !query analysis +Project [substr(utf8_binary#x, 2, 2) AS substr(utf8_binary, 2, 2)#x, substr(utf8_lcase#x, 2, 2) AS substr(utf8_lcase, 2, 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select substr(utf8_binary collate utf8_lcase, 2, 2), substr(utf8_lcase collate utf8_binary, 2, 2) from t5 +-- !query analysis +Project [substr(collate(utf8_binary#x, utf8_lcase), 2, 2) AS substr(collate(utf8_binary, utf8_lcase), 2, 2)#x, substr(collate(utf8_lcase#x, utf8_binary), 2, 2) AS substr(collate(utf8_lcase, utf8_binary), 2, 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select right(utf8_binary, 2), right(utf8_lcase, 2) from t5 +-- !query analysis +Project [right(utf8_binary#x, 2) AS right(utf8_binary, 2)#x, right(utf8_lcase#x, 2) AS right(utf8_lcase, 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select right(utf8_binary collate utf8_lcase, 2), right(utf8_lcase collate utf8_binary, 2) from t5 +-- !query analysis +Project [right(collate(utf8_binary#x, utf8_lcase), 2) AS right(collate(utf8_binary, utf8_lcase), 2)#x, right(collate(utf8_lcase#x, utf8_binary), 2) AS right(collate(utf8_lcase, utf8_binary), 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select left(utf8_binary, '2' collate utf8_lcase), left(utf8_lcase, 2) from t5 +-- !query analysis +Project [left(utf8_binary#x, cast(collate(2, utf8_lcase) as int)) AS left(utf8_binary, collate(2, utf8_lcase))#x, left(utf8_lcase#x, 2) AS left(utf8_lcase, 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select left(utf8_binary collate utf8_lcase, 2), left(utf8_lcase collate utf8_binary, 2) from t5 +-- !query analysis +Project [left(collate(utf8_binary#x, utf8_lcase), 2) AS left(collate(utf8_binary, utf8_lcase), 2)#x, left(collate(utf8_lcase#x, utf8_binary), 2) AS left(collate(utf8_lcase, utf8_binary), 2)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select 
rpad(utf8_binary, 8, utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select rpad(s, 8, utf8_binary) from t5 +-- !query analysis +Project [rpad(s#x, 8, utf8_binary#x) AS rpad(s, 8, utf8_binary)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select rpad(utf8_binary collate utf8_binary, 8, s collate utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select rpad(utf8_binary, 8, utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [rpad(utf8_binary#x, 8, collate(utf8_lcase#x, utf8_binary)) AS rpad(utf8_binary, 8, collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select rpad(utf8_binary collate utf8_lcase, 8, utf8_lcase collate utf8_lcase) from t5 +-- !query analysis +Project [rpad(collate(utf8_binary#x, utf8_lcase), 8, collate(utf8_lcase#x, utf8_lcase)) AS rpad(collate(utf8_binary, utf8_lcase), 8, collate(utf8_lcase, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select rpad(utf8_binary, 8, 'a'), rpad(utf8_lcase, 8, 'a') from t5 +-- !query analysis +Project [rpad(utf8_binary#x, 8, a) AS rpad(utf8_binary, 8, a)#x, rpad(utf8_lcase#x, 8, cast(a as string collate UTF8_LCASE)) AS rpad(utf8_lcase, 8, a)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select rpad(utf8_binary, 8, 'AaAA' collate utf8_lcase), rpad(utf8_lcase, 8, 'AAa' collate utf8_binary) from t5 +-- !query analysis +Project [rpad(cast(utf8_binary#x as string collate UTF8_LCASE), 8, collate(AaAA, utf8_lcase)) AS rpad(utf8_binary, 8, collate(AaAA, utf8_lcase))#x, rpad(cast(utf8_lcase#x as string), 8, collate(AAa, utf8_binary)) AS rpad(utf8_lcase, 8, collate(AAa, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select lpad(utf8_binary, 8, utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select lpad(s, 8, utf8_binary) from t5 +-- !query analysis +Project [lpad(s#x, 8, utf8_binary#x) AS lpad(s, 8, utf8_binary)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select lpad(utf8_binary collate utf8_binary, 8, s collate utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select lpad(utf8_binary, 8, utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [lpad(utf8_binary#x, 8, collate(utf8_lcase#x, utf8_binary)) AS lpad(utf8_binary, 8, collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation 
spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select lpad(utf8_binary collate utf8_lcase, 8, utf8_lcase collate utf8_lcase) from t5 +-- !query analysis +Project [lpad(collate(utf8_binary#x, utf8_lcase), 8, collate(utf8_lcase#x, utf8_lcase)) AS lpad(collate(utf8_binary, utf8_lcase), 8, collate(utf8_lcase, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select lpad(utf8_binary, 8, 'a'), lpad(utf8_lcase, 8, 'a') from t5 +-- !query analysis +Project [lpad(utf8_binary#x, 8, a) AS lpad(utf8_binary, 8, a)#x, lpad(utf8_lcase#x, 8, cast(a as string collate UTF8_LCASE)) AS lpad(utf8_lcase, 8, a)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select lpad(utf8_binary, 8, 'AaAA' collate utf8_lcase), lpad(utf8_lcase, 8, 'AAa' collate utf8_binary) from t5 +-- !query analysis +Project [lpad(cast(utf8_binary#x as string collate UTF8_LCASE), 8, collate(AaAA, utf8_lcase)) AS lpad(utf8_binary, 8, collate(AaAA, utf8_lcase))#x, lpad(cast(utf8_lcase#x as string), 8, collate(AAa, utf8_binary)) AS lpad(utf8_lcase, 8, collate(AAa, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select locate(utf8_binary, utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select locate(s, utf8_binary) from t5 +-- !query analysis +Project [locate(s#x, utf8_binary#x, 1) AS locate(s, utf8_binary, 1)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select locate(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select locate(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [locate(utf8_binary#x, collate(utf8_lcase#x, utf8_binary), 1) AS locate(utf8_binary, collate(utf8_lcase, utf8_binary), 1)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select locate(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 3) from t5 +-- !query analysis +Project [locate(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), 3) AS locate(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase), 3)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"locate(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 3)\"" + }, + "queryContext" : [ { + 
"objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 79, + "fragment" : "locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3)" + } ] +} + + +-- !query +select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5 +-- !query analysis +Project [locate(utf8_binary#x, a, 1) AS locate(utf8_binary, a, 1)#x, locate(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 1) AS locate(utf8_lcase, a, 1)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select locate(utf8_binary, 'AaAA' collate utf8_lcase, 4), locate(utf8_lcase, 'AAa' collate utf8_binary, 4) from t5 +-- !query analysis +Project [locate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(AaAA, utf8_lcase), 4) AS locate(utf8_binary, collate(AaAA, utf8_lcase), 4)#x, locate(cast(utf8_lcase#x as string), collate(AAa, utf8_binary), 4) AS locate(utf8_lcase, collate(AAa, utf8_binary), 4)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select TRIM(utf8_binary, utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select TRIM(s, utf8_binary) from t5 +-- !query analysis +Project [trim(utf8_binary#x, Some(s#x)) AS TRIM(BOTH s FROM utf8_binary)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select TRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string collate UTF8_LCASE`, `string`" + } +} + + +-- !query +select TRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [trim(collate(utf8_lcase#x, utf8_binary), Some(utf8_binary#x)) AS TRIM(BOTH utf8_binary FROM collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select TRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query analysis +Project [trim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, utf8_lcase))) AS TRIM(BOTH collate(utf8_binary, utf8_lcase) FROM collate(utf8_lcase, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 74, + "fragment" : "TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query +select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5 +-- !query analysis +Project 
[trim(utf8_binary#x, Some(ABc)) AS TRIM(BOTH ABc FROM utf8_binary)#x, trim(utf8_lcase#x, Some(cast(ABc as string collate UTF8_LCASE))) AS TRIM(BOTH ABc FROM utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select TRIM('ABc' collate utf8_lcase, utf8_binary), TRIM('AAa' collate utf8_binary, utf8_lcase) from t5 +-- !query analysis +Project [trim(cast(utf8_binary#x as string collate UTF8_LCASE), Some(collate(ABc, utf8_lcase))) AS TRIM(BOTH collate(ABc, utf8_lcase) FROM utf8_binary)#x, trim(cast(utf8_lcase#x as string), Some(collate(AAa, utf8_binary))) AS TRIM(BOTH collate(AAa, utf8_binary) FROM utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select BTRIM(utf8_binary, utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select BTRIM(s, utf8_binary) from t5 +-- !query analysis +Project [btrim(s#x, utf8_binary#x) AS btrim(s, utf8_binary)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select BTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select BTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [btrim(utf8_binary#x, collate(utf8_lcase#x, utf8_binary)) AS btrim(utf8_binary, collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select BTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query analysis +Project [btrim(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase)) AS btrim(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_lcase, unicode_ai) FROM collate(utf8_binary, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query +select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5 +-- !query analysis +Project [btrim(ABc, utf8_binary#x) AS btrim(ABc, utf8_binary)#x, btrim(ABc, utf8_lcase#x) AS btrim(ABc, utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select BTRIM('ABc' collate utf8_lcase, utf8_binary), 
BTRIM('AAa' collate utf8_binary, utf8_lcase) from t5 +-- !query analysis +Project [btrim(collate(ABc, utf8_lcase), utf8_binary#x) AS btrim(collate(ABc, utf8_lcase), utf8_binary)#x, btrim(collate(AAa, utf8_binary), utf8_lcase#x) AS btrim(collate(AAa, utf8_binary), utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select LTRIM(utf8_binary, utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select LTRIM(s, utf8_binary) from t5 +-- !query analysis +Project [ltrim(utf8_binary#x, Some(s#x)) AS TRIM(LEADING s FROM utf8_binary)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select LTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string collate UTF8_LCASE`, `string`" + } +} + + +-- !query +select LTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [ltrim(collate(utf8_lcase#x, utf8_binary), Some(utf8_binary#x)) AS TRIM(LEADING utf8_binary FROM collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select LTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query analysis +Project [ltrim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, utf8_lcase))) AS TRIM(LEADING collate(utf8_binary, utf8_lcase) FROM collate(utf8_lcase, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(LEADING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query +select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5 +-- !query analysis +Project [ltrim(utf8_binary#x, Some(ABc)) AS TRIM(LEADING ABc FROM utf8_binary)#x, ltrim(utf8_lcase#x, Some(cast(ABc as string collate UTF8_LCASE))) AS TRIM(LEADING ABc FROM utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select LTRIM('ABc' collate utf8_lcase, utf8_binary), LTRIM('AAa' collate utf8_binary, utf8_lcase) from t5 +-- !query analysis +Project [ltrim(cast(utf8_binary#x as string collate UTF8_LCASE), Some(collate(ABc, utf8_lcase))) AS TRIM(LEADING collate(ABc, utf8_lcase) FROM utf8_binary)#x, ltrim(cast(utf8_lcase#x as string), Some(collate(AAa, utf8_binary))) AS TRIM(LEADING 
collate(AAa, utf8_binary) FROM utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select RTRIM(utf8_binary, utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select RTRIM(s, utf8_binary) from t5 +-- !query analysis +Project [rtrim(utf8_binary#x, Some(s#x)) AS TRIM(TRAILING s FROM utf8_binary)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select RTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string collate UTF8_LCASE`, `string`" + } +} + + +-- !query +select RTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query analysis +Project [rtrim(collate(utf8_lcase#x, utf8_binary), Some(utf8_binary#x)) AS TRIM(TRAILING utf8_binary FROM collate(utf8_lcase, utf8_binary))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select RTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query analysis +Project [rtrim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, utf8_lcase))) AS TRIM(TRAILING collate(utf8_binary, utf8_lcase) FROM collate(utf8_lcase, utf8_lcase))#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(TRAILING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query +select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5 +-- !query analysis +Project [rtrim(utf8_binary#x, Some(ABc)) AS TRIM(TRAILING ABc FROM utf8_binary)#x, rtrim(utf8_lcase#x, Some(cast(ABc as string collate UTF8_LCASE))) AS TRIM(TRAILING ABc FROM utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +select RTRIM('ABc' collate utf8_lcase, utf8_binary), RTRIM('AAa' collate utf8_binary, utf8_lcase) from t5 +-- !query analysis +Project [rtrim(cast(utf8_binary#x as string collate UTF8_LCASE), Some(collate(ABc, utf8_lcase))) AS TRIM(TRAILING collate(ABc, utf8_lcase) FROM utf8_binary)#x, rtrim(cast(utf8_lcase#x as string), Some(collate(AAa, utf8_binary))) AS TRIM(TRAILING collate(AAa, utf8_binary) FROM utf8_lcase)#x] ++- SubqueryAlias spark_catalog.default.t5 + +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet + + +-- !query +drop table t5 +-- !query analysis +DropTable 
false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t5 + + +-- !query +drop table t6 +-- !query analysis +DropTable false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t6 + + +-- !query +drop table t7 +-- !query analysis +DropTable false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t7 + + +-- !query +drop table t8 +-- !query analysis +DropTable false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t8 + + +-- !query +drop table t9 +-- !query analysis +DropTable false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t9 + + +-- !query +drop table t10 +-- !query analysis +DropTable false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t10 diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out index 48137e06467e8..88c7d7b4e7d72 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out @@ -811,7 +811,7 @@ Project [to_date(26/October/2015, Some(dd/MMMMM/yyyy), Some(America/Los_Angeles) -- !query select from_json('{"d":"26/October/2015"}', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"d":"26/October/2015"})#x] +Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"d":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out index 1e49f4df8267a..4221db822d024 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out @@ -811,7 +811,7 @@ Project [to_date(26/October/2015, Some(dd/MMMMM/yyyy), Some(America/Los_Angeles) -- !query select from_json('{"d":"26/October/2015"}', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"d":"26/October/2015"})#x] +Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"d":"26/October/2015"})#x] +- OneRowRelation @@ -1833,7 +1833,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/execute-immediate.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/execute-immediate.sql.out index 78bf1ccb1678c..ce510527c8781 100644 --- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/execute-immediate.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/execute-immediate.sql.out @@ -471,7 +471,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'invalid_cast_error_expected'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -662,7 +661,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'name1'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out index 290e55052931d..efa149509751d 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out @@ -981,9 +981,13 @@ select interval '20 15:40:32.99899999' day to hour -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR` when cast to interval day to hour: 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "20 15:40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR`", + "typeName" : "interval day to hour" }, "queryContext" : [ { "objectType" : "", @@ -1000,9 +1004,13 @@ select interval '20 15:40:32.99899999' day to minute -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE` when cast to interval day to minute: 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "20 15:40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE`", + "typeName" : "interval day to minute" }, "queryContext" : [ { "objectType" : "", @@ -1019,9 +1027,13 @@ select interval '15:40:32.99899999' hour to minute -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE` when cast to interval hour to minute: 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "15:40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE`", + "typeName" : "interval hour to minute" }, "queryContext" : [ { "objectType" : "", @@ -1038,9 +1050,13 @@ select interval '15:40.99899999' hour to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 15:40.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "15:40.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`", + "typeName" : "interval hour to second" }, "queryContext" : [ { "objectType" : "", @@ -1057,9 +1073,13 @@ select interval '15:40' hour to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 15:40, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "15:40", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`", + "typeName" : "interval hour to second" }, "queryContext" : [ { "objectType" : "", @@ -1076,9 +1096,13 @@ select interval '20 40:32.99899999' minute to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND` when cast to interval minute to second: 20 40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "20 40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND`", + "typeName" : "interval minute to second" }, "queryContext" : [ { "objectType" : "", @@ -1460,9 +1484,11 @@ SELECT INTERVAL '178956970-8' YEAR TO MONTH -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.INTERVAL_PARSING", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Error parsing interval year-month string: integer overflow" + "input" : "178956970-8", + "interval" : "year-month" }, "queryContext" : [ { "objectType" : "", @@ -1909,9 +1935,13 @@ select interval '-\t2-2\t' year to month -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match year-month format of `[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH` when cast to interval year to month: -\t2-2\t" + "input" : "-\t2-2\t", + "intervalStr" : "year-month", + "supportedFormat" : "`[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH`", + "typeName" : "interval year to month" }, "queryContext" : [ { "objectType" : "", @@ -1935,9 +1965,13 @@ select interval '\n-\t10\t 12:34:46.789\t' day to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND` when cast to interval day to second: \n-\t10\t 12:34:46.789\t, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "\n-\t10\t 12:34:46.789\t", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND`", + "typeName" : "interval day to second" }, "queryContext" : [ { "objectType" : "", @@ -2074,7 +2108,7 @@ SELECT to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)), from_csv(to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)), 'a interval year, b interval month') -- !query analysis -Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x] +Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x] +- OneRowRelation @@ -2085,7 +2119,7 @@ SELECT to_json(map('a', interval 100 day 130 minute)), from_json(to_json(map('a', interval 100 day 130 minute)), 'a interval day to minute') -- !query analysis -Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), 
Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x] +Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x] +- OneRowRelation @@ -2096,7 +2130,7 @@ SELECT to_json(map('a', interval 32 year 10 month)), from_json(to_json(map('a', interval 32 year 10 month)), 'a interval year to month') -- !query analysis -Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x] +Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out index e81ee769f57d6..5bf893605423c 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out @@ -3017,6 +3017,53 @@ Project [c1#x, c2#x, t#x] +- LocalRelation [col1#x, col2#x] +-- !query +select 1 +from t1 as t_outer +left join + lateral( + select b1,b2 + from + ( + select + t2.c1 as b1, + 1 as b2 + from t2 + union + select t_outer.c1 as b1, + null as b2 + ) as t_inner + where (t_inner.b1 < t_outer.c2 or t_inner.b1 is null) + and t_inner.b1 = t_outer.c1 + order by t_inner.b1,t_inner.b2 desc limit 1 + ) as lateral_table +-- !query analysis +Project [1 AS 1#x] ++- LateralJoin lateral-subquery#x [c2#x && c1#x && c1#x], LeftOuter + : +- SubqueryAlias lateral_table + : +- GlobalLimit 1 + : +- LocalLimit 1 + : +- Sort [b1#x ASC NULLS FIRST, b2#x DESC NULLS LAST], true + : +- Project [b1#x, b2#x] + : +- Filter (((b1#x < outer(c2#x)) OR isnull(b1#x)) AND (b1#x = outer(c1#x))) + : +- SubqueryAlias t_inner + : +- Distinct + : +- Union 
false, false + : :- Project [c1#x AS b1#x, 1 AS b2#x] + : : +- SubqueryAlias spark_catalog.default.t2 + : : +- View (`spark_catalog`.`default`.`t2`, [c1#x, c2#x]) + : : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + : : +- LocalRelation [col1#x, col2#x] + : +- Project [b1#x, cast(b2#x as int) AS b2#x] + : +- Project [outer(c1#x) AS b1#x, null AS b2#x] + : +- OneRowRelation + +- SubqueryAlias t_outer + +- SubqueryAlias spark_catalog.default.t1 + +- View (`spark_catalog`.`default`.`t1`, [c1#x, c2#x]) + +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + +- LocalRelation [col1#x, col2#x] + + -- !query DROP VIEW t1 -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out index 0d7c6b2056231..fef9d0c5b6250 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out @@ -118,14 +118,14 @@ org.apache.spark.sql.AnalysisException -- !query select from_json('{"a":1}', 'a INT') -- !query analysis -Project [from_json(StructField(a,IntegerType,true), {"a":1}, Some(America/Los_Angeles)) AS from_json({"a":1})#x] +Project [from_json(StructField(a,IntegerType,true), {"a":1}, Some(America/Los_Angeles), false) AS from_json({"a":1})#x] +- OneRowRelation -- !query select from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')) -- !query analysis -Project [from_json(StructField(time,TimestampType,true), (timestampFormat,dd/MM/yyyy), {"time":"26/08/2015"}, Some(America/Los_Angeles)) AS from_json({"time":"26/08/2015"})#x] +Project [from_json(StructField(time,TimestampType,true), (timestampFormat,dd/MM/yyyy), {"time":"26/08/2015"}, Some(America/Los_Angeles), false) AS from_json({"time":"26/08/2015"})#x] +- OneRowRelation @@ -279,14 +279,14 @@ DropTempViewCommand jsonTable -- !query select from_json('{"a":1, "b":2}', 'map') -- !query analysis -Project [from_json(MapType(StringType,IntegerType,true), {"a":1, "b":2}, Some(America/Los_Angeles)) AS entries#x] +Project [from_json(MapType(StringType,IntegerType,true), {"a":1, "b":2}, Some(America/Los_Angeles), false) AS entries#x] +- OneRowRelation -- !query select from_json('{"a":1, "b":"2"}', 'struct') -- !query analysis -Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), {"a":1, "b":"2"}, Some(America/Los_Angeles)) AS from_json({"a":1, "b":"2"})#x] +Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), {"a":1, "b":"2"}, Some(America/Los_Angeles), false) AS from_json({"a":1, "b":"2"})#x] +- OneRowRelation @@ -300,70 +300,70 @@ Project [schema_of_json({"c1":0, "c2":[1]}) AS schema_of_json({"c1":0, "c2":[1]} -- !query select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')) -- !query analysis -Project [from_json(StructField(c1,ArrayType(LongType,true),true), {"c1":[1, 2, 3]}, Some(America/Los_Angeles)) AS from_json({"c1":[1, 2, 3]})#x] +Project [from_json(StructField(c1,ArrayType(LongType,true),true), {"c1":[1, 2, 3]}, Some(America/Los_Angeles), false) AS from_json({"c1":[1, 2, 3]})#x] +- OneRowRelation -- !query select from_json('[1, 2, 3]', 'array') -- !query analysis -Project [from_json(ArrayType(IntegerType,true), [1, 2, 3], Some(America/Los_Angeles)) AS from_json([1, 2, 3])#x] +Project [from_json(ArrayType(IntegerType,true), [1, 2, 3], Some(America/Los_Angeles), false) AS 
from_json([1, 2, 3])#x] +- OneRowRelation -- !query select from_json('[1, "2", 3]', 'array') -- !query analysis -Project [from_json(ArrayType(IntegerType,true), [1, "2", 3], Some(America/Los_Angeles)) AS from_json([1, "2", 3])#x] +Project [from_json(ArrayType(IntegerType,true), [1, "2", 3], Some(America/Los_Angeles), false) AS from_json([1, "2", 3])#x] +- OneRowRelation -- !query select from_json('[1, 2, null]', 'array') -- !query analysis -Project [from_json(ArrayType(IntegerType,true), [1, 2, null], Some(America/Los_Angeles)) AS from_json([1, 2, null])#x] +Project [from_json(ArrayType(IntegerType,true), [1, 2, null], Some(America/Los_Angeles), false) AS from_json([1, 2, null])#x] +- OneRowRelation -- !query select from_json('[{"a": 1}, {"a":2}]', 'array>') -- !query analysis -Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [{"a": 1}, {"a":2}], Some(America/Los_Angeles)) AS from_json([{"a": 1}, {"a":2}])#x] +Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [{"a": 1}, {"a":2}], Some(America/Los_Angeles), false) AS from_json([{"a": 1}, {"a":2}])#x] +- OneRowRelation -- !query select from_json('{"a": 1}', 'array>') -- !query analysis -Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), {"a": 1}, Some(America/Los_Angeles)) AS from_json({"a": 1})#x] +Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), {"a": 1}, Some(America/Los_Angeles), false) AS from_json({"a": 1})#x] +- OneRowRelation -- !query select from_json('[null, {"a":2}]', 'array>') -- !query analysis -Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [null, {"a":2}], Some(America/Los_Angeles)) AS from_json([null, {"a":2}])#x] +Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [null, {"a":2}], Some(America/Los_Angeles), false) AS from_json([null, {"a":2}])#x] +- OneRowRelation -- !query select from_json('[{"a": 1}, {"b":2}]', 'array>') -- !query analysis -Project [from_json(ArrayType(MapType(StringType,IntegerType,true),true), [{"a": 1}, {"b":2}], Some(America/Los_Angeles)) AS from_json([{"a": 1}, {"b":2}])#x] +Project [from_json(ArrayType(MapType(StringType,IntegerType,true),true), [{"a": 1}, {"b":2}], Some(America/Los_Angeles), false) AS from_json([{"a": 1}, {"b":2}])#x] +- OneRowRelation -- !query select from_json('[{"a": 1}, 2]', 'array>') -- !query analysis -Project [from_json(ArrayType(MapType(StringType,IntegerType,true),true), [{"a": 1}, 2], Some(America/Los_Angeles)) AS from_json([{"a": 1}, 2])#x] +Project [from_json(ArrayType(MapType(StringType,IntegerType,true),true), [{"a": 1}, 2], Some(America/Los_Angeles), false) AS from_json([{"a": 1}, 2])#x] +- OneRowRelation -- !query select from_json('{"d": "2012-12-15", "t": "2012-12-15 15:15:15"}', 'd date, t timestamp') -- !query analysis -Project [from_json(StructField(d,DateType,true), StructField(t,TimestampType,true), {"d": "2012-12-15", "t": "2012-12-15 15:15:15"}, Some(America/Los_Angeles)) AS from_json({"d": "2012-12-15", "t": "2012-12-15 15:15:15"})#x] +Project [from_json(StructField(d,DateType,true), StructField(t,TimestampType,true), {"d": "2012-12-15", "t": "2012-12-15 15:15:15"}, Some(America/Los_Angeles), false) AS from_json({"d": "2012-12-15", "t": "2012-12-15 15:15:15"})#x] +- OneRowRelation @@ -373,7 +373,7 @@ select from_json( 'd date, t timestamp', map('dateFormat', 'MM/dd yyyy', 'timestampFormat', 'MM/dd yyyy HH:mm:ss')) -- !query analysis -Project 
[from_json(StructField(d,DateType,true), StructField(t,TimestampType,true), (dateFormat,MM/dd yyyy), (timestampFormat,MM/dd yyyy HH:mm:ss), {"d": "12/15 2012", "t": "12/15 2012 15:15:15"}, Some(America/Los_Angeles)) AS from_json({"d": "12/15 2012", "t": "12/15 2012 15:15:15"})#x] +Project [from_json(StructField(d,DateType,true), StructField(t,TimestampType,true), (dateFormat,MM/dd yyyy), (timestampFormat,MM/dd yyyy HH:mm:ss), {"d": "12/15 2012", "t": "12/15 2012 15:15:15"}, Some(America/Los_Angeles), false) AS from_json({"d": "12/15 2012", "t": "12/15 2012 15:15:15"})#x] +- OneRowRelation @@ -383,7 +383,7 @@ select from_json( 'd date', map('dateFormat', 'MM-dd')) -- !query analysis -Project [from_json(StructField(d,DateType,true), (dateFormat,MM-dd), {"d": "02-29"}, Some(America/Los_Angeles)) AS from_json({"d": "02-29"})#x] +Project [from_json(StructField(d,DateType,true), (dateFormat,MM-dd), {"d": "02-29"}, Some(America/Los_Angeles), false) AS from_json({"d": "02-29"})#x] +- OneRowRelation @@ -393,7 +393,7 @@ select from_json( 't timestamp', map('timestampFormat', 'MM-dd')) -- !query analysis -Project [from_json(StructField(t,TimestampType,true), (timestampFormat,MM-dd), {"t": "02-29"}, Some(America/Los_Angeles)) AS from_json({"t": "02-29"})#x] +Project [from_json(StructField(t,TimestampType,true), (timestampFormat,MM-dd), {"t": "02-29"}, Some(America/Los_Angeles), false) AS from_json({"t": "02-29"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/parse-schema-string.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/parse-schema-string.sql.out index 45fc3bd03a782..ae8e47ed3665c 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/parse-schema-string.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/parse-schema-string.sql.out @@ -16,12 +16,12 @@ Project [from_csv(StructField(cube,IntegerType,true), 1, Some(America/Los_Angele -- !query select from_json('{"create":1}', 'create INT') -- !query analysis -Project [from_json(StructField(create,IntegerType,true), {"create":1}, Some(America/Los_Angeles)) AS from_json({"create":1})#x] +Project [from_json(StructField(create,IntegerType,true), {"create":1}, Some(America/Los_Angeles), false) AS from_json({"create":1})#x] +- OneRowRelation -- !query select from_json('{"cube":1}', 'cube INT') -- !query analysis -Project [from_json(StructField(cube,IntegerType,true), {"cube":1}, Some(America/Los_Angeles)) AS from_json({"cube":1})#x] +Project [from_json(StructField(cube,IntegerType,true), {"cube":1}, Some(America/Los_Angeles), false) AS from_json({"cube":1})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out new file mode 100644 index 0000000000000..c44ce153a2f41 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out @@ -0,0 +1,590 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +drop table if exists t +-- !query analysis +DropTable true, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t + + +-- !query +create table t(x int, y string) using csv +-- !query analysis +CreateDataSourceTableCommand `spark_catalog`.`default`.`t`, false + + +-- !query +insert into t values (0, 'abc'), (1, 'def') +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t, false, CSV, [path=file:[not included in 
comparison]/{warehouse_dir}/t], Append, `spark_catalog`.`default`.`t`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t), [x, y] ++- Project [cast(col1#x as int) AS x#x, cast(col2#x as string) AS y#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +drop table if exists other +-- !query analysis +DropTable true, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.other + + +-- !query +create table other(a int, b int) using json +-- !query analysis +CreateDataSourceTableCommand `spark_catalog`.`default`.`other`, false + + +-- !query +insert into other values (1, 1), (1, 2), (2, 4) +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/other, false, JSON, [path=file:[not included in comparison]/{warehouse_dir}/other], Append, `spark_catalog`.`default`.`other`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/other), [a, b] ++- Project [cast(col1#x as int) AS a#x, cast(col2#x as int) AS b#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +drop table if exists st +-- !query analysis +DropTable true, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.st + + +-- !query +create table st(x int, col struct) using parquet +-- !query analysis +CreateDataSourceTableCommand `spark_catalog`.`default`.`st`, false + + +-- !query +insert into st values (1, (2, 3)) +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/st, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/st], Append, `spark_catalog`.`default`.`st`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/st), [x, col] ++- Project [cast(col1#x as int) AS x#x, named_struct(i1, cast(col2#x.col1 as int), i2, cast(col2#x.col2 as int)) AS col#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +table t +|> select 1 as x +-- !query analysis +Project [pipeselect(1) AS x#x] ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select x, y +-- !query analysis +Project [x#x, y#x] ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select x, y +|> select x + length(y) as z +-- !query analysis +Project [pipeselect((x#x + length(y#x))) AS z#x] ++- Project [x#x, y#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +values (0), (1) tab(col) +|> select col * 2 as result +-- !query analysis +Project [pipeselect((col#x * 2)) AS result#x] ++- SubqueryAlias tab + +- LocalRelation [col#x] + + +-- !query +(select * from t union all select * from t) +|> select x + length(y) as result +-- !query analysis +Project [pipeselect((x#x + length(y#x))) AS result#x] ++- Union false, false + :- Project [x#x, y#x] + : +- SubqueryAlias spark_catalog.default.t + : +- Relation spark_catalog.default.t[x#x,y#x] csv + +- Project [x#x, y#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(table t + |> select x, y + |> select x) +union all +select x from t where x < 1 +-- !query analysis +Union false, false +:- Project [x#x] +: +- Project [x#x, y#x] +: +- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- Project [x#x] + +- Filter (x#x < 1) + 
+- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select col from st) +|> select col.i1 +-- !query analysis +Project [col#x.i1 AS i1#x] ++- Project [col#x] + +- SubqueryAlias spark_catalog.default.st + +- Relation spark_catalog.default.st[x#x,col#x] parquet + + +-- !query +table st +|> select st.col.i1 +-- !query analysis +Project [col#x.i1 AS i1#x] ++- SubqueryAlias spark_catalog.default.st + +- Relation spark_catalog.default.st[x#x,col#x] parquet + + +-- !query +table t +|> select (select a from other where x = a limit 1) as result +-- !query analysis +Project [pipeselect(scalar-subquery#x [x#x]) AS result#x] +: +- GlobalLimit 1 +: +- LocalLimit 1 +: +- Project [a#x] +: +- Filter (outer(x#x) = a#x) +: +- SubqueryAlias spark_catalog.default.other +: +- Relation spark_catalog.default.other[a#x,b#x] json ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +select (values (0) tab(col) |> select col) as result +-- !query analysis +Project [scalar-subquery#x [] AS result#x] +: +- Project [col#x] +: +- SubqueryAlias tab +: +- LocalRelation [col#x] ++- OneRowRelation + + +-- !query +table t +|> select (select any_value(a) from other where x = a limit 1) as result +-- !query analysis +Project [pipeselect(scalar-subquery#x [x#x]) AS result#x] +: +- GlobalLimit 1 +: +- LocalLimit 1 +: +- Aggregate [any_value(a#x, false) AS any_value(a)#x] +: +- Filter (outer(x#x) = a#x) +: +- SubqueryAlias spark_catalog.default.other +: +- Relation spark_catalog.default.other[a#x,b#x] json ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select x + length(x) as z, z + 1 as plus_one +-- !query analysis +Project [z#x, pipeselect((z#x + 1)) AS plus_one#x] ++- Project [x#x, y#x, pipeselect((x#x + length(cast(x#x as string)))) AS z#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select first_value(x) over (partition by y) as result +-- !query analysis +Project [result#x] ++- Project [x#x, y#x, _we0#x, pipeselect(_we0#x) AS result#x] + +- Window [first_value(x#x, false) windowspecdefinition(y#x, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#x], [y#x] + +- Project [x#x, y#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +select 1 x, 2 y, 3 z +|> select 1 + sum(x) over (), + avg(y) over (), + x, + avg(x+1) over (partition by y order by z) AS a2 +|> select a2 +-- !query analysis +Project [a2#x] ++- Project [(1 + sum(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING))#xL, avg(y) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#x, x#x, a2#x] + +- Project [x#x, y#x, _w1#x, z#x, _we0#xL, avg(y) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#x, _we2#x, (cast(1 as bigint) + _we0#xL) AS (1 + sum(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING))#xL, avg(y) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#x, pipeselect(_we2#x) AS a2#x] + +- Window [avg(_w1#x) windowspecdefinition(y#x, z#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS _we2#x], [y#x], [z#x ASC NULLS FIRST] + +- Window [sum(x#x) windowspecdefinition(specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#xL, avg(y#x) 
windowspecdefinition(specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS avg(y) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#x] + +- Project [x#x, y#x, (x#x + 1) AS _w1#x, z#x] + +- Project [1 AS x#x, 2 AS y#x, 3 AS z#x] + +- OneRowRelation + + +-- !query +table t +|> select x, count(*) over () +|> select x +-- !query analysis +Project [x#x] ++- Project [x#x, count(1) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#xL] + +- Project [x#x, count(1) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#xL, count(1) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#xL] + +- Window [count(1) windowspecdefinition(specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS count(1) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#xL] + +- Project [x#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select distinct x, y +-- !query analysis +Distinct ++- Project [x#x, y#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select * +-- !query analysis +Project [x#x, y#x] ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select * except (y) +-- !query analysis +Project [x#x] ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select /*+ repartition(3) */ * +-- !query analysis +Repartition 3, true ++- Project [x#x, y#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select /*+ repartition(3) */ distinct x +-- !query analysis +Repartition 3, true ++- Distinct + +- Project [x#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select /*+ repartition(3) */ all x +-- !query analysis +Repartition 3, true ++- Project [x#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select sum(x) as result +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "PIPE_OPERATOR_SELECT_CONTAINS_AGGREGATE_FUNCTION", + "sqlState" : "0A000", + "messageParameters" : { + "expr" : "sum(x#x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 19, + "stopIndex" : 24, + "fragment" : "sum(x)" + } ] +} + + +-- !query +table t +|> select y, length(y) + sum(x) as result +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "PIPE_OPERATOR_SELECT_CONTAINS_AGGREGATE_FUNCTION", + "sqlState" : "0A000", + "messageParameters" : { + "expr" : "sum(x#x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 34, + "stopIndex" : 39, + "fragment" : "sum(x)" + } ] +} + + +-- !query +table t +|> where true +-- !query analysis +Filter true ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where x + length(y) < 4 +-- !query analysis +Filter ((x#x + length(y#x)) < 4) ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where x + length(y) < 4 +|> where x + length(y) < 3 +-- !query analysis +Filter ((x#x + length(y#x)) < 3) ++- SubqueryAlias __auto_generated_subquery_name + +- 
Filter ((x#x + length(y#x)) < 4) + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where x = 1 +-- !query analysis +Filter (x#x = 1) ++- SubqueryAlias __auto_generated_subquery_name + +- Aggregate [x#x], [x#x, sum(length(y#x)) AS sum_len#xL] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where t.x = 1 +-- !query analysis +Filter (x#x = 1) ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where spark_catalog.default.t.x = 1 +-- !query analysis +Filter (x#x = 1) ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select col from st) +|> where col.i1 = 1 +-- !query analysis +Filter (col#x.i1 = 1) ++- SubqueryAlias __auto_generated_subquery_name + +- Project [col#x] + +- SubqueryAlias spark_catalog.default.st + +- Relation spark_catalog.default.st[x#x,col#x] parquet + + +-- !query +table st +|> where st.col.i1 = 2 +-- !query analysis +Filter (col#x.i1 = 2) ++- SubqueryAlias spark_catalog.default.st + +- Relation spark_catalog.default.st[x#x,col#x] parquet + + +-- !query +table t +|> where exists (select a from other where x = a limit 1) +-- !query analysis +Filter exists#x [x#x] +: +- GlobalLimit 1 +: +- LocalLimit 1 +: +- Project [a#x] +: +- Filter (outer(x#x) = a#x) +: +- SubqueryAlias spark_catalog.default.other +: +- Relation spark_catalog.default.other[a#x,b#x] json ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where (select any_value(a) from other where x = a limit 1) = 1 +-- !query analysis +Filter (scalar-subquery#x [x#x] = 1) +: +- GlobalLimit 1 +: +- LocalLimit 1 +: +- Aggregate [any_value(a#x, false) AS any_value(a)#x] +: +- Filter (outer(x#x) = a#x) +: +- SubqueryAlias spark_catalog.default.other +: +- Relation spark_catalog.default.other[a#x,b#x] json ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where sum(x) = 1 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"(sum(x) = 1)\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 27, + "fragment" : "table t\n|> where sum(x) = 1" + } ] +} + + +-- !query +table t +|> where y = 'abc' or length(y) + sum(x) = 1 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"((y = abc) OR ((length(y) + sum(x)) = 1))\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 52, + "fragment" : "table t\n|> where y = 'abc' or length(y) + sum(x) = 1" + } ] +} + + +-- !query +table t +|> where first_value(x) over (partition by y) = 1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +select * from t where first_value(x) over (partition by y) = 1 +-- !query analysis 
+org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +table t +|> select x, length(y) as z +|> where x + length(y) < 4 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `z`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 57, + "stopIndex" : 57, + "fragment" : "y" + } ] +} + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where sum(length(y)) = 3 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `sum_len`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 77, + "stopIndex" : 77, + "fragment" : "y" + } ] +} + + +-- !query +drop table t +-- !query analysis +DropTable false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t + + +-- !query +drop table other +-- !query analysis +DropTable false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.other + + +-- !query +drop table st +-- !query analysis +DropTable false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.st diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/interval.sql.out index 8d41651cb743a..1add0830d9b77 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/interval.sql.out @@ -88,9 +88,13 @@ SELECT interval '1 2:03' day to hour -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR` when cast to interval day to hour: 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR`", + "typeName" : "interval day to hour" }, "queryContext" : [ { "objectType" : "", @@ -107,9 +111,13 @@ SELECT interval '1 2:03:04' day to hour -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR` when cast to interval day to hour: 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "1 2:03:04", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR`", + "typeName" : "interval day to hour" }, "queryContext" : [ { "objectType" : "", @@ -133,9 +141,13 @@ SELECT interval '1 2:03:04' day to minute -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE` when cast to interval day to minute: 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03:04", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE`", + "typeName" : "interval day to minute" }, "queryContext" : [ { "objectType" : "", @@ -152,9 +164,13 @@ SELECT interval '1 2:03' day to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND` when cast to interval day to second: 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND`", + "typeName" : "interval day to second" }, "queryContext" : [ { "objectType" : "", @@ -178,9 +194,13 @@ SELECT interval '1 2:03' hour to minute -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE` when cast to interval hour to minute: 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE`", + "typeName" : "interval hour to minute" }, "queryContext" : [ { "objectType" : "", @@ -197,9 +217,13 @@ SELECT interval '1 2:03:04' hour to minute -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE` when cast to interval hour to minute: 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "1 2:03:04", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE`", + "typeName" : "interval hour to minute" }, "queryContext" : [ { "objectType" : "", @@ -216,9 +240,13 @@ SELECT interval '1 2:03' hour to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`", + "typeName" : "interval hour to second" }, "queryContext" : [ { "objectType" : "", @@ -235,9 +263,13 @@ SELECT interval '1 2:03:04' hour to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03:04", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`", + "typeName" : "interval hour to second" }, "queryContext" : [ { "objectType" : "", @@ -254,9 +286,13 @@ SELECT interval '1 2:03' minute to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND` when cast to interval minute to second: 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND`", + "typeName" : "interval minute to second" }, "queryContext" : [ { "objectType" : "", @@ -273,9 +309,13 @@ SELECT interval '1 2:03:04' minute to second -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND` when cast to interval minute to second: 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "1 2:03:04", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND`", + "typeName" : "interval minute to second" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part2.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part2.sql.out index cdcd563de4f6a..330e1c1cad7ef 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part2.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part2.sql.out @@ -449,7 +449,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'NaN'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out index 3cacbdc141053..133cd6a60a4fb 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out @@ -93,3 +93,404 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "fragment" : "rand('1')" } ] } + + +-- !query +SELECT uniform(0, 1, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(0, 10, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(0L, 10L, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(0, 10L, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(0, 10S, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(10, 20, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(10.0F, 20.0F, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(10.0D, 20.0D, CAST(3 / 7 AS LONG)) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(10, 20.0F, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(10, 20, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(10, 20.0F) IS NOT NULL AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(NULL, 1, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(0, NULL, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(0, 1, NULL) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(10, 20, col) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"col\"", + "inputName" : 
"seed", + "inputType" : "integer or floating-point", + "sqlExpr" : "\"uniform(10, 20, col)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 27, + "fragment" : "uniform(10, 20, col)" + } ] +} + + +-- !query +SELECT uniform(col, 10, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"col\"", + "inputName" : "min", + "inputType" : "integer or floating-point", + "sqlExpr" : "\"uniform(col, 10, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 26, + "fragment" : "uniform(col, 10, 0)" + } ] +} + + +-- !query +SELECT uniform(10) AS result +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "1", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[2, 3]", + "functionName" : "`uniform`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 18, + "fragment" : "uniform(10)" + } ] +} + + +-- !query +SELECT uniform(10, 20, 30, 40) AS result +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "4", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[2, 3]", + "functionName" : "`uniform`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 30, + "fragment" : "uniform(10, 20, 30, 40)" + } ] +} + + +-- !query +SELECT randstr(1, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT randstr(5, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT randstr(10, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT randstr(10S, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT randstr(10, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT randstr(10) IS NOT NULL AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT randstr(10L, 0) AS result +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"10\"", + "inputType" : "\"BIGINT\"", + "paramIndex" : "first", + "requiredType" : "INT or SMALLINT", + "sqlExpr" : "\"randstr(10, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 22, + "fragment" : "randstr(10L, 0)" + } ] +} + + +-- !query +SELECT randstr(10.0F, 0) AS result +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"10.0\"", + "inputType" : "\"FLOAT\"", + "paramIndex" : "first", + "requiredType" : "INT or SMALLINT", + "sqlExpr" : "\"randstr(10.0, 0)\"" + 
}, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "randstr(10.0F, 0)" + } ] +} + + +-- !query +SELECT randstr(10.0D, 0) AS result +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"10.0\"", + "inputType" : "\"DOUBLE\"", + "paramIndex" : "first", + "requiredType" : "INT or SMALLINT", + "sqlExpr" : "\"randstr(10.0, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "randstr(10.0D, 0)" + } ] +} + + +-- !query +SELECT randstr(NULL, 0) AS result +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"NULL\"", + "inputType" : "\"VOID\"", + "paramIndex" : "first", + "requiredType" : "INT or SMALLINT", + "sqlExpr" : "\"randstr(NULL, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 23, + "fragment" : "randstr(NULL, 0)" + } ] +} + + +-- !query +SELECT randstr(0, NULL) AS result +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"NULL\"", + "inputType" : "\"VOID\"", + "paramIndex" : "second", + "requiredType" : "INT or SMALLINT", + "sqlExpr" : "\"randstr(0, NULL)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 23, + "fragment" : "randstr(0, NULL)" + } ] +} + + +-- !query +SELECT randstr(col, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"col\"", + "inputName" : "length", + "inputType" : "INT or SMALLINT", + "sqlExpr" : "\"randstr(col, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 22, + "fragment" : "randstr(col, 0)" + } ] +} + + +-- !query +SELECT randstr(10, col) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"col\"", + "inputName" : "seedExpression", + "inputType" : "INT or SMALLINT", + "sqlExpr" : "\"randstr(10, col)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 23, + "fragment" : "randstr(10, col)" + } ] +} + + +-- !query +SELECT randstr(10, 0, 1) AS result +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "3", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[1, 2]", + "functionName" : "`randstr`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "randstr(10, 0, 1)" + } ] +} diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out index eb48f0d9a28f0..02e7c39ae83fd 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out @@ -842,7 +842,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'hello'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -885,7 +884,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"INT\"", "targetType" : "\"SMALLINT\"", "value" : "100000" @@ -1002,7 +1000,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"DOUBLE\"", "targetType" : "\"INT\"", "value" : "1.0E10D" @@ -1062,7 +1059,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'hello'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -2151,7 +2147,7 @@ CreateVariable defaultvalueexpression(cast(a INT as string), 'a INT'), true -- !query SELECT from_json('{"a": 1}', var1) -- !query analysis -Project [from_json(StructField(a,IntegerType,true), {"a": 1}, Some(America/Los_Angeles)) AS from_json({"a": 1})#x] +Project [from_json(StructField(a,IntegerType,true), {"a": 1}, Some(America/Los_Angeles), false) AS from_json({"a": 1})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subexp-elimination.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subexp-elimination.sql.out index 94073f2751b3e..754b05bfa6fed 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subexp-elimination.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subexp-elimination.sql.out @@ -15,7 +15,7 @@ AS testData(a, b), false, true, LocalTempView, UNSUPPORTED, true -- !query SELECT from_json(a, 'struct').a, from_json(a, 'struct').b, from_json(b, 'array>')[0].a, from_json(b, 'array>')[0].b FROM testData -- !query analysis -Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a AS from_json(a).a#x, from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).b AS from_json(a).b#x, from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].a AS from_json(b)[0].a#x, from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b AS from_json(b)[0].b#x] +Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a AS from_json(a).a#x, from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).b AS from_json(a).b#x, from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].a AS from_json(b)[0].a#x, 
from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b AS from_json(b)[0].b#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] @@ -27,7 +27,7 @@ Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,tru -- !query SELECT if(from_json(a, 'struct').a > 1, from_json(b, 'array>')[0].a, from_json(b, 'array>')[0].a + 1) FROM testData -- !query analysis -Project [if ((from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 1)) from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].a else (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].a + 1) AS (IF((from_json(a).a > 1), from_json(b)[0].a, (from_json(b)[0].a + 1)))#x] +Project [if ((from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 1)) from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].a else (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].a + 1) AS (IF((from_json(a).a > 1), from_json(b)[0].a, (from_json(b)[0].a + 1)))#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] @@ -39,7 +39,7 @@ Project [if ((from_json(StructField(a,IntegerType,true), StructField(b,StringTyp -- !query SELECT if(isnull(from_json(a, 'struct').a), from_json(b, 'array>')[0].b + 1, from_json(b, 'array>')[0].b) FROM testData -- !query analysis -Project [if (isnull(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a)) (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b + 1) else from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b AS (IF((from_json(a).a IS NULL), (from_json(b)[0].b + 1), from_json(b)[0].b))#x] +Project [if (isnull(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a)) (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b + 1) else from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b AS (IF((from_json(a).a IS NULL), (from_json(b)[0].b + 1), from_json(b)[0].b))#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] @@ -51,7 +51,7 @@ Project [if (isnull(from_json(StructField(a,IntegerType,true), StructField(b,Str -- !query SELECT case when from_json(a, 'struct').a > 5 then from_json(a, 'struct').b when from_json(a, 'struct').a > 4 then from_json(a, 'struct').b + 1 else from_json(a, 'struct').b + 2 end FROM testData -- !query analysis -Project [CASE WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 5) THEN 
from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).b WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 4) THEN cast((cast(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).b as double) + cast(1 as double)) as string) ELSE cast((cast(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).b as double) + cast(2 as double)) as string) END AS CASE WHEN (from_json(a).a > 5) THEN from_json(a).b WHEN (from_json(a).a > 4) THEN (from_json(a).b + 1) ELSE (from_json(a).b + 2) END#x] +Project [CASE WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 5) THEN from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).b WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 4) THEN cast((cast(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).b as double) + cast(1 as double)) as string) ELSE cast((cast(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).b as double) + cast(2 as double)) as string) END AS CASE WHEN (from_json(a).a > 5) THEN from_json(a).b WHEN (from_json(a).a > 4) THEN (from_json(a).b + 1) ELSE (from_json(a).b + 2) END#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] @@ -63,7 +63,7 @@ Project [CASE WHEN (from_json(StructField(a,IntegerType,true), StructField(b,Str -- !query SELECT case when from_json(a, 'struct').a > 5 then from_json(b, 'array>')[0].b when from_json(a, 'struct').a > 4 then from_json(b, 'array>')[0].b + 1 else from_json(b, 'array>')[0].b + 2 end FROM testData -- !query analysis -Project [CASE WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 5) THEN from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 4) THEN (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b + 1) ELSE (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b + 2) END AS CASE WHEN (from_json(a).a > 5) THEN from_json(b)[0].b WHEN (from_json(a).a > 4) THEN (from_json(b)[0].b + 1) ELSE (from_json(b)[0].b + 2) END#x] +Project [CASE WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 5) THEN from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 4) THEN (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b + 1) ELSE 
(from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b + 2) END AS CASE WHEN (from_json(a).a > 5) THEN from_json(b)[0].b WHEN (from_json(a).a > 4) THEN (from_json(b)[0].b + 1) ELSE (from_json(b)[0].b + 2) END#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/timestamp.sql.out index 6ca35b8b141dc..dcfd783b648f8 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/timestamp.sql.out @@ -802,7 +802,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp-ansi.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp-ansi.sql.out index e50c860270563..ec227afc87fe1 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp-ansi.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp-ansi.sql.out @@ -745,7 +745,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(t,TimestampNTZType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project [from_json(StructField(t,TimestampNTZType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp.sql.out index 098abfb3852cf..7475f837250d5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp.sql.out @@ -805,7 +805,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(t,TimestampNTZType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project [from_json(StructField(t,TimestampNTZType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/typeCoercion/native/stringCastAndExpressions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/typeCoercion/native/stringCastAndExpressions.sql.out index 009e91f7ffacf..22e60d0606382 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/typeCoercion/native/stringCastAndExpressions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/typeCoercion/native/stringCastAndExpressions.sql.out @@ -370,7 +370,7 @@ Project [c0#x] -- !query select from_json(a, 'a INT') from t -- !query analysis -Project [from_json(StructField(a,IntegerType,true), a#x, Some(America/Los_Angeles)) AS from_json(a)#x] +Project [from_json(StructField(a,IntegerType,true), a#x, Some(America/Los_Angeles), false) AS from_json(a)#x] +- SubqueryAlias t +- View (`t`, [a#x]) +- Project [cast(a#x as string) AS a#x] diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql b/sql/core/src/test/resources/sql-tests/inputs/collations.sql index 51d8d1be4154c..f3a42fd3e1f12 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/collations.sql @@ -99,30 +99,324 @@ insert into t4 values('a:1,b:2,c:3', ',', ':'); select str_to_map(text, pairDelim, keyValueDelim) from t4; select str_to_map(text collate utf8_binary, pairDelim collate utf8_lcase, keyValueDelim collate utf8_binary) from t4; select str_to_map(text collate utf8_binary, pairDelim collate utf8_binary, keyValueDelim collate utf8_binary) from t4; +select str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai) from t4; drop table t4; --- create table for split_part -create table t5(str string collate utf8_binary, delimiter string collate utf8_lcase, partNum int) using parquet; - -insert into t5 values('11AB12AB13', 'AB', 2); - -select split_part(str, delimiter, partNum) from t5; -select split_part(str collate utf8_binary, delimiter collate utf8_lcase, partNum) from t5; -select split_part(str collate utf8_binary, delimiter collate utf8_binary, partNum) from t5; +create table t5(s string, utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase) using parquet; +insert into t5 values ('Spark', 'Spark', 'SQL'); +insert into t5 values ('aaAaAAaA', 'aaAaAAaA', 'aaAaAAaA'); +insert into t5 values ('aaAaAAaA', 'aaAaAAaA', 'aaAaaAaA'); +insert into t5 values ('aaAaAAaA', 'aaAaAAaA', 'aaAaaAaAaaAaaAaAaaAaaAaA'); +insert into t5 values ('bbAbaAbA', 'bbAbAAbA', 'a'); +insert into t5 values ('İo', 'İo', 'İo'); +insert into t5 values ('İo', 'İo', 'i̇o'); +insert into t5 values ('efd2', 'efd2', 'efd2'); +insert into t5 values ('Hello, world! Nice day.', 'Hello, world! Nice day.', 'Hello, world! Nice day.'); +insert into t5 values ('Something else. Nothing here.', 'Something else. Nothing here.', 'Something else. 
Nothing here.'); +insert into t5 values ('kitten', 'kitten', 'sitTing'); +insert into t5 values ('abc', 'abc', 'abc'); +insert into t5 values ('abcdcba', 'abcdcba', 'aBcDCbA'); + +create table t6(ascii long) using parquet; +insert into t6 values (97); +insert into t6 values (66); + +create table t7(ascii double) using parquet; +insert into t7 values (97.52143); +insert into t7 values (66.421); + +create table t8(format string collate utf8_binary, utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase) using parquet; +insert into t8 values ('%s%s', 'abCdE', 'abCdE'); + +create table t9(num long) using parquet; +insert into t9 values (97); +insert into t9 values (66); + +create table t10(utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase) using parquet; +insert into t10 values ('aaAaAAaA', 'aaAaaAaA'); +insert into t10 values ('efd2', 'efd2'); + +-- ConcatWs +select concat_ws(' ', utf8_lcase, utf8_lcase) from t5; +select concat_ws(' ', utf8_binary, utf8_lcase) from t5; +select concat_ws(' ' collate utf8_binary, utf8_binary, 'SQL' collate utf8_lcase) from t5; +select concat_ws(' ' collate utf8_lcase, utf8_binary, 'SQL' collate utf8_lcase) from t5; +select concat_ws(',', utf8_lcase, 'word'), concat_ws(',', utf8_binary, 'word') from t5; +select concat_ws(',', utf8_lcase, 'word' collate utf8_binary), concat_ws(',', utf8_binary, 'word' collate utf8_lcase) from t5; + +-- Elt +select elt(2, s, utf8_binary) from t5; +select elt(2, utf8_binary, utf8_lcase, s) from t5; +select elt(1, utf8_binary collate utf8_binary, utf8_lcase collate utf8_lcase) from t5; +select elt(1, utf8_binary collate utf8_binary, utf8_lcase collate utf8_binary) from t5; +select elt(1, utf8_binary collate utf8_binary, utf8_lcase) from t5; +select elt(1, utf8_binary, 'word'), elt(1, utf8_lcase, 'word') from t5; +select elt(1, utf8_binary, 'word' collate utf8_lcase), elt(1, utf8_lcase, 'word' collate utf8_binary) from t5; + +-- SplitPart +select split_part(utf8_binary, utf8_lcase, 3) from t5; +select split_part(s, utf8_binary, 1) from t5; +select split_part(utf8_binary collate utf8_binary, s collate utf8_lcase, 1) from t5; +select split_part(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5; +select split_part(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5; +select split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5; +select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5; +select split_part(utf8_binary, 'a' collate utf8_lcase, 3), split_part(utf8_lcase, 'a' collate utf8_binary, 3) from t5; + +-- Contains +select contains(utf8_binary, utf8_lcase) from t5; +select contains(s, utf8_binary) from t5; +select contains(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; +select contains(utf8_binary, utf8_lcase collate utf8_binary) from t5; +select contains(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; +select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5; +select contains(utf8_binary, 'AaAA' collate utf8_lcase), contains(utf8_lcase, 'AAa' collate utf8_binary) from t5; + +-- SubstringIndex +select substring_index(utf8_binary, utf8_lcase, 2) from t5; +select substring_index(s, utf8_binary,1) from t5; +select substring_index(utf8_binary collate utf8_binary, s collate utf8_lcase, 3) from t5; +select substring_index(utf8_binary, utf8_lcase collate utf8_binary, 2) 
from t5; +select substring_index(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5; +select substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5; +select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5; +select substring_index(utf8_binary, 'AaAA' collate utf8_lcase, 2), substring_index(utf8_lcase, 'AAa' collate utf8_binary, 2) from t5; + +-- StringInStr +select instr(utf8_binary, utf8_lcase) from t5; +select instr(s, utf8_binary) from t5; +select instr(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; +select instr(utf8_binary, utf8_lcase collate utf8_binary) from t5; +select instr(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; +select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5; +select instr(utf8_binary, 'AaAA' collate utf8_lcase), instr(utf8_lcase, 'AAa' collate utf8_binary) from t5; + +-- FindInSet +select find_in_set(utf8_binary, utf8_lcase) from t5; +select find_in_set(s, utf8_binary) from t5; +select find_in_set(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; +select find_in_set(utf8_binary, utf8_lcase collate utf8_binary) from t5; +select find_in_set(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select find_in_set(utf8_binary, 'aaAaaAaA,i̇o'), find_in_set(utf8_lcase, 'aaAaaAaA,i̇o') from t5; +select find_in_set(utf8_binary, 'aaAaaAaA,i̇o' collate utf8_lcase), find_in_set(utf8_lcase, 'aaAaaAaA,i̇o' collate utf8_binary) from t5; + +-- StartsWith +select startswith(utf8_binary, utf8_lcase) from t5; +select startswith(s, utf8_binary) from t5; +select startswith(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; +select startswith(utf8_binary, utf8_lcase collate utf8_binary) from t5; +select startswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; +select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5; +select startswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), startswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5; + +-- StringTranslate +select translate(utf8_lcase, utf8_lcase, '12345') from t5; +select translate(utf8_binary, utf8_lcase, '12345') from t5; +select translate(utf8_binary, 'aBc' collate utf8_lcase, '12345' collate utf8_binary) from t5; +select translate(utf8_binary, 'SQL' collate utf8_lcase, '12345' collate utf8_lcase) from t5; +select translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai) from t5; +select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5; +select translate(utf8_lcase, 'aBc' collate utf8_binary, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5; + +-- Replace +select replace(utf8_binary, utf8_lcase, 'abc') from t5; +select replace(s, utf8_binary, 'abc') from t5; +select replace(utf8_binary collate utf8_binary, s collate utf8_lcase, 'abc') from t5; +select replace(utf8_binary, utf8_lcase collate utf8_binary, 'abc') from t5; +select replace(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 'abc') from t5; +select replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc') from t5; +select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5; +select 
replace(utf8_binary, 'aaAaaAaA' collate utf8_lcase, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5; + +-- EndsWith +select endswith(utf8_binary, utf8_lcase) from t5; +select endswith(s, utf8_binary) from t5; +select endswith(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; +select endswith(utf8_binary, utf8_lcase collate utf8_binary) from t5; +select endswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; +select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5; +select endswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), endswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5; + +-- StringRepeat +select repeat(utf8_binary, 3), repeat(utf8_lcase, 2) from t5; +select repeat(utf8_binary collate utf8_lcase, 3), repeat(utf8_lcase collate utf8_binary, 2) from t5; + +-- Ascii & UnBase64 string expressions +select ascii(utf8_binary), ascii(utf8_lcase) from t5; +select ascii(utf8_binary collate utf8_lcase), ascii(utf8_lcase collate utf8_binary) from t5; +select unbase64(utf8_binary), unbase64(utf8_lcase) from t10; +select unbase64(utf8_binary collate utf8_lcase), unbase64(utf8_lcase collate utf8_binary) from t10; + +-- Chr +select chr(ascii) from t6; + +-- Base64, Decode +select base64(utf8_binary), base64(utf8_lcase) from t5; +select base64(utf8_binary collate utf8_lcase), base64(utf8_lcase collate utf8_binary) from t5; +select decode(encode(utf8_binary, 'utf-8'), 'utf-8'), decode(encode(utf8_lcase, 'utf-8'), 'utf-8') from t5; +select decode(encode(utf8_binary collate utf8_lcase, 'utf-8'), 'utf-8'), decode(encode(utf8_lcase collate utf8_binary, 'utf-8'), 'utf-8') from t5; + +-- FormatNumber +select format_number(ascii, '###.###') from t7; +select format_number(ascii, '###.###' collate utf8_lcase) from t7; + +-- Encode, ToBinary +select encode(utf8_binary, 'utf-8'), encode(utf8_lcase, 'utf-8') from t5; +select encode(utf8_binary collate utf8_lcase, 'utf-8'), encode(utf8_lcase collate utf8_binary, 'utf-8') from t5; +select to_binary(utf8_binary, 'utf-8'), to_binary(utf8_lcase, 'utf-8') from t5; +select to_binary(utf8_binary collate utf8_lcase, 'utf-8'), to_binary(utf8_lcase collate utf8_binary, 'utf-8') from t5; + +-- Sentences +select sentences(utf8_binary), sentences(utf8_lcase) from t5; +select sentences(utf8_binary collate utf8_lcase), sentences(utf8_lcase collate utf8_binary) from t5; + +-- Upper +select upper(utf8_binary), upper(utf8_lcase) from t5; +select upper(utf8_binary collate utf8_lcase), upper(utf8_lcase collate utf8_binary) from t5; + +-- Lower +select lower(utf8_binary), lower(utf8_lcase) from t5; +select lower(utf8_binary collate utf8_lcase), lower(utf8_lcase collate utf8_binary) from t5; + +-- InitCap +select initcap(utf8_binary), initcap(utf8_lcase) from t5; +select initcap(utf8_binary collate utf8_lcase), initcap(utf8_lcase collate utf8_binary) from t5; + +-- Overlay +select overlay(utf8_binary, utf8_lcase, 2) from t5; +select overlay(s, utf8_binary,1) from t5; +select overlay(utf8_binary collate utf8_binary, s collate utf8_lcase, 3) from t5; +select overlay(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5; +select overlay(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5; +select overlay(utf8_binary, 'a', 2), overlay(utf8_lcase, 'a', 2) from t5; +select overlay(utf8_binary, 'AaAA' collate utf8_lcase, 2), overlay(utf8_lcase, 'AAa' collate utf8_binary, 2) from 
t5; + +-- FormatString +select format_string(format, utf8_binary, utf8_lcase) from t8; +select format_string(format collate utf8_lcase, utf8_lcase, utf8_binary collate utf8_lcase, 3), format_string(format, utf8_lcase collate utf8_binary, utf8_binary) from t8; +select format_string(format, utf8_binary, utf8_lcase) from t8; + +-- SoundEx +select soundex(utf8_binary), soundex(utf8_lcase) from t5; +select soundex(utf8_binary collate utf8_lcase), soundex(utf8_lcase collate utf8_binary) from t5; + +-- Length, BitLength & OctetLength +select length(utf8_binary), length(utf8_lcase) from t5; +select length(utf8_binary collate utf8_lcase), length(utf8_lcase collate utf8_binary) from t5; +select bit_length(utf8_binary), bit_length(utf8_lcase) from t5; +select bit_length(utf8_binary collate utf8_lcase), bit_length(utf8_lcase collate utf8_binary) from t5; +select octet_length(utf8_binary), octet_length(utf8_lcase) from t5; +select octet_length(utf8_binary collate utf8_lcase), octet_length(utf8_lcase collate utf8_binary) from t5; + +-- Luhncheck +select luhn_check(num) from t9; + +-- Levenshtein +select levenshtein(utf8_binary, utf8_lcase) from t5; +select levenshtein(s, utf8_binary) from t5; +select levenshtein(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; +select levenshtein(utf8_binary, utf8_lcase collate utf8_binary) from t5; +select levenshtein(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select levenshtein(utf8_binary, 'a'), levenshtein(utf8_lcase, 'a') from t5; +select levenshtein(utf8_binary, 'AaAA' collate utf8_lcase, 3), levenshtein(utf8_lcase, 'AAa' collate utf8_binary, 4) from t5; + +-- IsValidUTF8 +select is_valid_utf8(utf8_binary), is_valid_utf8(utf8_lcase) from t5; +select is_valid_utf8(utf8_binary collate utf8_lcase), is_valid_utf8(utf8_lcase collate utf8_binary) from t5; + +-- MakeValidUTF8 +select make_valid_utf8(utf8_binary), make_valid_utf8(utf8_lcase) from t5; +select make_valid_utf8(utf8_binary collate utf8_lcase), make_valid_utf8(utf8_lcase collate utf8_binary) from t5; + +-- ValidateUTF8 +select validate_utf8(utf8_binary), validate_utf8(utf8_lcase) from t5; +select validate_utf8(utf8_binary collate utf8_lcase), validate_utf8(utf8_lcase collate utf8_binary) from t5; + +-- TryValidateUTF8 +select try_validate_utf8(utf8_binary), try_validate_utf8(utf8_lcase) from t5; +select try_validate_utf8(utf8_binary collate utf8_lcase), try_validate_utf8(utf8_lcase collate utf8_binary) from t5; + +-- Left/Right/Substr +select substr(utf8_binary, 2, 2), substr(utf8_lcase, 2, 2) from t5; +select substr(utf8_binary collate utf8_lcase, 2, 2), substr(utf8_lcase collate utf8_binary, 2, 2) from t5; +select right(utf8_binary, 2), right(utf8_lcase, 2) from t5; +select right(utf8_binary collate utf8_lcase, 2), right(utf8_lcase collate utf8_binary, 2) from t5; +select left(utf8_binary, '2' collate utf8_lcase), left(utf8_lcase, 2) from t5; +select left(utf8_binary collate utf8_lcase, 2), left(utf8_lcase collate utf8_binary, 2) from t5; + +-- StringRPad +select rpad(utf8_binary, 8, utf8_lcase) from t5; +select rpad(s, 8, utf8_binary) from t5; +select rpad(utf8_binary collate utf8_binary, 8, s collate utf8_lcase) from t5; +select rpad(utf8_binary, 8, utf8_lcase collate utf8_binary) from t5; +select rpad(utf8_binary collate utf8_lcase, 8, utf8_lcase collate utf8_lcase) from t5; +select rpad(utf8_binary, 8, 'a'), rpad(utf8_lcase, 8, 'a') from t5; +select rpad(utf8_binary, 8, 'AaAA' collate utf8_lcase), rpad(utf8_lcase, 8, 'AAa' collate utf8_binary) from t5; + +-- 
StringLPad +select lpad(utf8_binary, 8, utf8_lcase) from t5; +select lpad(s, 8, utf8_binary) from t5; +select lpad(utf8_binary collate utf8_binary, 8, s collate utf8_lcase) from t5; +select lpad(utf8_binary, 8, utf8_lcase collate utf8_binary) from t5; +select lpad(utf8_binary collate utf8_lcase, 8, utf8_lcase collate utf8_lcase) from t5; +select lpad(utf8_binary, 8, 'a'), lpad(utf8_lcase, 8, 'a') from t5; +select lpad(utf8_binary, 8, 'AaAA' collate utf8_lcase), lpad(utf8_lcase, 8, 'AAa' collate utf8_binary) from t5; + +-- Locate +select locate(utf8_binary, utf8_lcase) from t5; +select locate(s, utf8_binary) from t5; +select locate(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; +select locate(utf8_binary, utf8_lcase collate utf8_binary) from t5; +select locate(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 3) from t5; +select locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3) from t5; +select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5; +select locate(utf8_binary, 'AaAA' collate utf8_lcase, 4), locate(utf8_lcase, 'AAa' collate utf8_binary, 4) from t5; + +-- StringTrim +select TRIM(utf8_binary, utf8_lcase) from t5; +select TRIM(s, utf8_binary) from t5; +select TRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; +select TRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; +select TRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; +select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5; +select TRIM('ABc' collate utf8_lcase, utf8_binary), TRIM('AAa' collate utf8_binary, utf8_lcase) from t5; +-- StringTrimBoth +select BTRIM(utf8_binary, utf8_lcase) from t5; +select BTRIM(s, utf8_binary) from t5; +select BTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; +select BTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; +select BTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; +select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5; +select BTRIM('ABc' collate utf8_lcase, utf8_binary), BTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; +-- StringTrimLeft +select LTRIM(utf8_binary, utf8_lcase) from t5; +select LTRIM(s, utf8_binary) from t5; +select LTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; +select LTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; +select LTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; +select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5; +select LTRIM('ABc' collate utf8_lcase, utf8_binary), LTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; +-- StringTrimRight +select RTRIM(utf8_binary, utf8_lcase) from t5; +select RTRIM(s, utf8_binary) from t5; +select RTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; +select RTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; +select RTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; +select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5; +select RTRIM('ABc' collate utf8_lcase, utf8_binary), RTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; drop table t5; - --- 
create table for levenshtein -create table t6 (utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase, threshold int) using parquet; - -insert into t6 values('kitten', 'sitting', 2); - -select levenshtein(utf8_binary, utf8_lcase) from t6; -select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_lcase) from t6; -select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_binary) from t6; -select levenshtein(utf8_binary, utf8_lcase, threshold) from t6; -select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_lcase, threshold) from t6; -select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_binary, threshold) from t6; - drop table t6; +drop table t7; +drop table t8; +drop table t9; +drop table t10; diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql index 8bff1f109aa65..e3cef9207d20f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql @@ -531,6 +531,27 @@ select * from t1 join lateral (select t4.c1 as t from t4 where t1.c1 = t4.c1)) as foo order by foo.t limit 5); + +select 1 +from t1 as t_outer +left join + lateral( + select b1,b2 + from + ( + select + t2.c1 as b1, + 1 as b2 + from t2 + union + select t_outer.c1 as b1, + null as b2 + ) as t_inner + where (t_inner.b1 < t_outer.c2 or t_inner.b1 is null) + and t_inner.b1 = t_outer.c1 + order by t_inner.b1,t_inner.b2 desc limit 1 + ) as lateral_table; + -- clean up DROP VIEW t1; DROP VIEW t2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql new file mode 100644 index 0000000000000..49a72137ee047 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql @@ -0,0 +1,192 @@ +-- Prepare some test data. +-------------------------- +drop table if exists t; +create table t(x int, y string) using csv; +insert into t values (0, 'abc'), (1, 'def'); + +drop table if exists other; +create table other(a int, b int) using json; +insert into other values (1, 1), (1, 2), (2, 4); + +drop table if exists st; +create table st(x int, col struct) using parquet; +insert into st values (1, (2, 3)); + +-- SELECT operators: positive tests. +--------------------------------------- + +-- Selecting a constant. +table t +|> select 1 as x; + +-- Selecting attributes. +table t +|> select x, y; + +-- Chained pipe SELECT operators. +table t +|> select x, y +|> select x + length(y) as z; + +-- Using the VALUES list as the source relation. +values (0), (1) tab(col) +|> select col * 2 as result; + +-- Using a table subquery as the source relation. +(select * from t union all select * from t) +|> select x + length(y) as result; + +-- Enclosing the result of a pipe SELECT operation in a table subquery. +(table t + |> select x, y + |> select x) +union all +select x from t where x < 1; + +-- Selecting struct fields. +(select col from st) +|> select col.i1; + +table st +|> select st.col.i1; + +-- Expression subqueries in the pipe operator SELECT list. +table t +|> select (select a from other where x = a limit 1) as result; + +-- Pipe operator SELECT inside expression subqueries. +select (values (0) tab(col) |> select col) as result; + +-- Aggregations are allowed within expression subqueries in the pipe operator SELECT list as long as +-- no aggregate functions exist in the top-level select list. 
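A minimal sketch (not taken from the patch itself) of the pipe semantics these new pipe-operators.sql inputs exercise, reusing the table t(x int, y string) created at the top of that file; each |> step consumes only the output of the step directly before it:

-- Chained pipe operators; the WHERE step sees only x and z from the
-- preceding SELECT, not the original column y.
table t
|> select x, length(y) as z
|> where z < 4;

-- A roughly equivalent formulation with an ordinary nested subquery
-- (the subquery alias "piped" is arbitrary).
select x, z
from (select x, length(y) as z from t) as piped
where z < 4;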
+table t +|> select (select any_value(a) from other where x = a limit 1) as result; + +-- Lateral column aliases in the pipe operator SELECT list. +table t +|> select x + length(x) as z, z + 1 as plus_one; + +-- Window functions are allowed in the pipe operator SELECT list. +table t +|> select first_value(x) over (partition by y) as result; + +select 1 x, 2 y, 3 z +|> select 1 + sum(x) over (), + avg(y) over (), + x, + avg(x+1) over (partition by y order by z) AS a2 +|> select a2; + +table t +|> select x, count(*) over () +|> select x; + +-- DISTINCT is supported. +table t +|> select distinct x, y; + +-- SELECT * is supported. +table t +|> select *; + +table t +|> select * except (y); + +-- Hints are supported. +table t +|> select /*+ repartition(3) */ *; + +table t +|> select /*+ repartition(3) */ distinct x; + +table t +|> select /*+ repartition(3) */ all x; + +-- SELECT operators: negative tests. +--------------------------------------- + +-- Aggregate functions are not allowed in the pipe operator SELECT list. +table t +|> select sum(x) as result; + +table t +|> select y, length(y) + sum(x) as result; + +-- WHERE operators: positive tests. +----------------------------------- + +-- Filtering with a constant predicate. +table t +|> where true; + +-- Filtering with a predicate based on attributes from the input relation. +table t +|> where x + length(y) < 4; + +-- Two consecutive filters are allowed. +table t +|> where x + length(y) < 4 +|> where x + length(y) < 3; + +-- It is possible to use the WHERE operator instead of the HAVING clause when processing the result +-- of aggregations. For example, this WHERE operator is equivalent to the normal SQL "HAVING x = 1". +(select x, sum(length(y)) as sum_len from t group by x) +|> where x = 1; + +-- Filtering by referring to the table or table subquery alias. +table t +|> where t.x = 1; + +table t +|> where spark_catalog.default.t.x = 1; + +-- Filtering using struct fields. +(select col from st) +|> where col.i1 = 1; + +table st +|> where st.col.i1 = 2; + +-- Expression subqueries in the WHERE clause. +table t +|> where exists (select a from other where x = a limit 1); + +-- Aggregations are allowed within expression subqueries in the pipe operator WHERE clause as long +-- no aggregate functions exist in the top-level expression predicate. +table t +|> where (select any_value(a) from other where x = a limit 1) = 1; + +-- WHERE operators: negative tests. +----------------------------------- + +-- Aggregate functions are not allowed in the top-level WHERE predicate. +-- (Note: to implement this behavior, perform the aggregation first separately and then add a +-- pipe-operator WHERE clause referring to the result of aggregate expression(s) therein). +table t +|> where sum(x) = 1; + +table t +|> where y = 'abc' or length(y) + sum(x) = 1; + +-- Window functions are not allowed in the WHERE clause (pipe operators or otherwise). +table t +|> where first_value(x) over (partition by y) = 1; + +select * from t where first_value(x) over (partition by y) = 1; + +-- Pipe operators may only refer to attributes produced as output from the directly-preceding +-- pipe operator, not from earlier ones. +table t +|> select x, length(y) as z +|> where x + length(y) < 4; + +-- If the WHERE clause wants to filter rows produced by an aggregation, it is not valid to try to +-- refer to the aggregate functions directly; it is necessary to use aliases instead. +(select x, sum(length(y)) as sum_len from t group by x) +|> where sum(length(y)) = 3; + +-- Cleanup. 
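Just ahead of the cleanup block, one more sketch (not taken from the patch itself) of the WHERE-after-aggregation rule spelled out in the comments above: a pipe WHERE over an aggregated input should go through the alias the aggregation produced, while plain SQL expresses the same filter with HAVING.

-- Valid form suggested by the comments: filter on the alias sum_len.
(select x, sum(length(y)) as sum_len from t group by x)
|> where sum_len = 3;

-- Ordinary SQL expressing the same filter with HAVING.
select x, sum(length(y)) as sum_len from t group by x
having sum(length(y)) = 3;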
+----------- +drop table t; +drop table other; +drop table st; diff --git a/sql/core/src/test/resources/sql-tests/inputs/random.sql b/sql/core/src/test/resources/sql-tests/inputs/random.sql index a1aae7b8759dc..a71b0293295fc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/random.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/random.sql @@ -14,4 +14,44 @@ SELECT randn(NULL); SELECT randn(cast(NULL AS long)); -- randn unsupported data type -SELECT rand('1') +SELECT rand('1'); + +-- The uniform random number generation function supports generating random numbers within a +-- specified range. We use a seed of zero for these queries to keep tests deterministic. +SELECT uniform(0, 1, 0) AS result; +SELECT uniform(0, 10, 0) AS result; +SELECT uniform(0L, 10L, 0) AS result; +SELECT uniform(0, 10L, 0) AS result; +SELECT uniform(0, 10S, 0) AS result; +SELECT uniform(10, 20, 0) AS result; +SELECT uniform(10.0F, 20.0F, 0) AS result; +SELECT uniform(10.0D, 20.0D, CAST(3 / 7 AS LONG)) AS result; +SELECT uniform(10, 20.0F, 0) AS result; +SELECT uniform(10, 20, 0) AS result FROM VALUES (0), (1), (2) tab(col); +SELECT uniform(10, 20.0F) IS NOT NULL AS result; +-- Negative test cases for the uniform random number generator. +SELECT uniform(NULL, 1, 0) AS result; +SELECT uniform(0, NULL, 0) AS result; +SELECT uniform(0, 1, NULL) AS result; +SELECT uniform(10, 20, col) AS result FROM VALUES (0), (1), (2) tab(col); +SELECT uniform(col, 10, 0) AS result FROM VALUES (0), (1), (2) tab(col); +SELECT uniform(10) AS result; +SELECT uniform(10, 20, 30, 40) AS result; + +-- The randstr random string generation function supports generating random strings within a +-- specified length. We use a seed of zero for these queries to keep tests deterministic. +SELECT randstr(1, 0) AS result; +SELECT randstr(5, 0) AS result; +SELECT randstr(10, 0) AS result; +SELECT randstr(10S, 0) AS result; +SELECT randstr(10, 0) AS result FROM VALUES (0), (1), (2) tab(col); +SELECT randstr(10) IS NOT NULL AS result; +-- Negative test cases for the randstr random number generator. 
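The positive cases above, together with the negative cases that follow, pin down the argument rules for the new generators: randstr expects an INT or SMALLINT length, uniform expects numeric bounds, and the optional seed of either is expected to be a constant (foldable) expression. A brief sketch of what the analyzer accepts and rejects under those rules:

-- Accepted: integer literals with a constant seed.
SELECT randstr(10, 0) AS s, uniform(10, 20, 0) AS n;
-- Rejected (DATATYPE_MISMATCH): the length must be INT or SMALLINT, not BIGINT.
SELECT randstr(10L, 0);
-- Rejected (NON_FOLDABLE_INPUT): the seed cannot be a per-row column.
SELECT randstr(10, col) FROM VALUES (0), (1), (2) tab(col);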
+SELECT randstr(10L, 0) AS result; +SELECT randstr(10.0F, 0) AS result; +SELECT randstr(10.0D, 0) AS result; +SELECT randstr(NULL, 0) AS result; +SELECT randstr(0, NULL) AS result; +SELECT randstr(col, 0) AS result FROM VALUES (0), (1), (2) tab(col); +SELECT randstr(10, col) AS result FROM VALUES (0), (1), (2) tab(col); +SELECT randstr(10, 0, 1) AS result; diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out index d17d87900fc71..7394e428091c7 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out @@ -151,27 +151,9 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query select sort_array(array('b', 'd'), cast(NULL as boolean)) -- !query schema -struct<> +struct> -- !query output -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - "sqlState" : "42K09", - "messageParameters" : { - "inputSql" : "\"CAST(NULL AS BOOLEAN)\"", - "inputType" : "\"BOOLEAN\"", - "paramIndex" : "second", - "requiredType" : "\"BOOLEAN\"", - "sqlExpr" : "\"sort_array(array(b, d), CAST(NULL AS BOOLEAN))\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 57, - "fragment" : "sort_array(array('b', 'd'), cast(NULL as boolean))" - } ] -} +NULL -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out index 7dd7180165f2b..0dbdf1d9975c9 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out @@ -9,7 +9,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'1.23'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -34,7 +33,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'1.23'", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" @@ -59,7 +57,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'-4.56'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -84,7 +81,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'-4.56'", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" @@ -109,7 +105,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'abc'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -134,7 +129,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'abc'", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" @@ -159,7 +153,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", 
"messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'abc'", "sourceType" : "\"STRING\"", "targetType" : "\"FLOAT\"" @@ -184,7 +177,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'abc'", "sourceType" : "\"STRING\"", "targetType" : "\"DOUBLE\"" @@ -209,7 +201,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'1234567890123'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -234,7 +225,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'12345678901234567890123'", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" @@ -259,7 +249,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "''", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -284,7 +273,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "''", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" @@ -309,7 +297,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "''", "sourceType" : "\"STRING\"", "targetType" : "\"FLOAT\"" @@ -334,7 +321,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "''", "sourceType" : "\"STRING\"", "targetType" : "\"DOUBLE\"" @@ -375,7 +361,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'123.a'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -400,7 +385,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'123.a'", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" @@ -425,7 +409,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'123.a'", "sourceType" : "\"STRING\"", "targetType" : "\"FLOAT\"" @@ -450,7 +433,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'123.a'", "sourceType" : "\"STRING\"", "targetType" : "\"DOUBLE\"" @@ -483,7 +465,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'-2147483649'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -516,7 +497,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : 
"CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'2147483648'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -549,7 +529,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'-9223372036854775809'", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" @@ -582,7 +561,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'9223372036854775808'", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" @@ -973,7 +951,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'1中文'", "sourceType" : "\"STRING\"", "targetType" : "\"TINYINT\"" @@ -998,7 +975,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'1中文'", "sourceType" : "\"STRING\"", "targetType" : "\"SMALLINT\"" @@ -1023,7 +999,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'1中文'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -1048,7 +1023,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'中文1'", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" @@ -1073,7 +1047,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'1中文'", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" @@ -1116,7 +1089,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'\t\n xyz \t\r'", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" @@ -1174,7 +1146,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'xyz'", "sourceType" : "\"STRING\"", "targetType" : "\"DECIMAL(4,2)\"" @@ -1207,7 +1178,6 @@ org.apache.spark.SparkDateTimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'a'", "sourceType" : "\"STRING\"", "targetType" : "\"DATE\"" @@ -1240,7 +1210,6 @@ org.apache.spark.SparkDateTimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'a'", "sourceType" : "\"STRING\"", "targetType" : "\"TIMESTAMP\"" @@ -1273,7 +1242,6 @@ org.apache.spark.SparkDateTimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'a'", "sourceType" : "\"STRING\"", "targetType" : "\"TIMESTAMP_NTZ\"" @@ 
-1298,7 +1266,6 @@ org.apache.spark.SparkDateTimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "Infinity", "sourceType" : "\"DOUBLE\"", "targetType" : "\"TIMESTAMP\"" @@ -1323,7 +1290,6 @@ org.apache.spark.SparkDateTimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "Infinity", "sourceType" : "\"DOUBLE\"", "targetType" : "\"TIMESTAMP\"" @@ -1380,7 +1346,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"INTERVAL HOUR TO SECOND\"", "targetType" : "\"SMALLINT\"", "value" : "INTERVAL '23:59:59' HOUR TO SECOND" @@ -1414,7 +1379,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"INTERVAL MONTH\"", "targetType" : "\"TINYINT\"", "value" : "INTERVAL '-1000' MONTH" @@ -1432,7 +1396,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"INTERVAL SECOND\"", "targetType" : "\"SMALLINT\"", "value" : "INTERVAL '1000000' SECOND" @@ -1522,7 +1485,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"INT\"", "targetType" : "\"INTERVAL YEAR\"", "value" : "2147483647" @@ -1540,7 +1502,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"BIGINT\"", "targetType" : "\"INTERVAL DAY\"", "value" : "-9223372036854775808L" @@ -1671,7 +1632,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'1.23'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -1696,7 +1656,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'abc'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -1721,7 +1680,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'12345678901234567890123'", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" @@ -1746,7 +1704,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "''", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -1779,7 +1736,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'123.a'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -1927,7 +1883,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" 
: { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'1.23'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -1952,7 +1907,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"BIGINT\"", "targetType" : "\"INT\"", "value" : "2147483648L" @@ -1970,7 +1924,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"BIGINT\"", "targetType" : "\"INT\"", "value" : "2147483648L" diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/conditional-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/conditional-functions.sql.out index 26293cad10ce4..aa8a600f87560 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/conditional-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/conditional-functions.sql.out @@ -145,7 +145,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'abc'", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" @@ -181,7 +180,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'abc'", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out index b1b26b2f74ad1..67cd23faf2556 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out @@ -268,7 +268,8 @@ struct<> -- !query output org.apache.spark.SparkIllegalArgumentException { - "errorClass" : "_LEGACY_ERROR_TEMP_3209", + "errorClass" : "ILLEGAL_DAY_OF_WEEK", + "sqlState" : "22009", "messageParameters" : { "string" : "xx" } @@ -309,7 +310,6 @@ org.apache.spark.SparkDateTimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'xx'", "sourceType" : "\"STRING\"", "targetType" : "\"DATE\"" @@ -468,7 +468,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'1.2'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -643,7 +642,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'1.2'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/datetime-parsing-invalid.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/datetime-parsing-invalid.sql.out index 514a0c6ae7d31..0708a523900ff 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/datetime-parsing-invalid.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/datetime-parsing-invalid.sql.out @@ -427,7 +427,6 @@ org.apache.spark.SparkDateTimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" 
: "\"spark.sql.ansi.enabled\"", "expression" : "'Unparseable'", "sourceType" : "\"STRING\"", "targetType" : "\"TIMESTAMP\"" @@ -452,7 +451,6 @@ org.apache.spark.SparkDateTimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'Unparseable'", "sourceType" : "\"STRING\"", "targetType" : "\"DATE\"" diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 4e220ba9885c1..b2f85835eb0df 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -130,7 +130,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'a'", "sourceType" : "\"STRING\"", "targetType" : "\"DOUBLE\"" @@ -155,7 +154,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'a'", "sourceType" : "\"STRING\"", "targetType" : "\"DOUBLE\"" @@ -180,7 +178,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'a'", "sourceType" : "\"STRING\"", "targetType" : "\"DOUBLE\"" @@ -205,7 +202,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'a'", "sourceType" : "\"STRING\"", "targetType" : "\"DOUBLE\"" @@ -246,7 +242,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'a'", "sourceType" : "\"STRING\"", "targetType" : "\"DOUBLE\"" @@ -271,7 +266,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'a'", "sourceType" : "\"STRING\"", "targetType" : "\"DOUBLE\"" @@ -1261,9 +1255,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR` when cast to interval day to hour: 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "20 15:40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR`", + "typeName" : "interval day to hour" }, "queryContext" : [ { "objectType" : "", @@ -1282,9 +1280,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE` when cast to interval day to minute: 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "20 15:40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE`", + "typeName" : "interval day to minute" }, "queryContext" : [ { "objectType" : "", @@ -1303,9 +1305,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE` when cast to interval hour to minute: 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "15:40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE`", + "typeName" : "interval hour to minute" }, "queryContext" : [ { "objectType" : "", @@ -1324,9 +1330,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 15:40.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "15:40.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`", + "typeName" : "interval hour to second" }, "queryContext" : [ { "objectType" : "", @@ -1345,9 +1355,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 15:40, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "15:40", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`", + "typeName" : "interval hour to second" }, "queryContext" : [ { "objectType" : "", @@ -1366,9 +1380,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND` when cast to interval minute to second: 20 40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "20 40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND`", + "typeName" : "interval minute to second" }, "queryContext" : [ { "objectType" : "", @@ -1797,9 +1815,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.INTERVAL_PARSING", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Error parsing interval year-month string: integer overflow" + "input" : "178956970-8", + "interval" : "year-month" }, "queryContext" : [ { "objectType" : "", @@ -1945,7 +1965,6 @@ org.apache.spark.SparkDateTimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'4 11:11'", "sourceType" : "\"STRING\"", "targetType" : "\"TIMESTAMP\"" @@ -1970,7 +1989,6 @@ org.apache.spark.SparkDateTimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'4 12:12:12'", "sourceType" : "\"STRING\"", "targetType" : "\"TIMESTAMP\"" @@ -2051,7 +2069,6 @@ org.apache.spark.SparkDateTimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'1'", "sourceType" : "\"STRING\"", "targetType" : "\"TIMESTAMP\"" @@ -2076,7 +2093,6 @@ org.apache.spark.SparkDateTimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'1'", "sourceType" : "\"STRING\"", "targetType" : "\"TIMESTAMP\"" @@ -2348,9 +2364,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match year-month format of `[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH` when cast to interval year to month: -\t2-2\t" + "input" : "-\t2-2\t", + "intervalStr" : "year-month", + "supportedFormat" : "`[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH`", + "typeName" : "interval year to month" }, "queryContext" : [ { "objectType" : "", @@ -2377,9 +2397,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h:m:s.n`, `INTERVAL 
[+|-]'[+|-]d h:m:s.n' DAY TO SECOND` when cast to interval day to second: \n-\t10\t 12:34:46.789\t, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "\n-\t10\t 12:34:46.789\t", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND`", + "typeName" : "interval day to second" }, "queryContext" : [ { "objectType" : "", @@ -2866,7 +2890,7 @@ SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1.0D struct<> -- !query output java.lang.ArithmeticException -not in range +rounded value is out of range for input 2.147483648E9 and rounding mode HALF_UP -- !query @@ -2946,7 +2970,7 @@ SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1.0D struct<> -- !query output java.lang.ArithmeticException -not in range +rounded value is out of range for input 9.223372036854776E18 and rounding mode HALF_UP -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out index 5735e5eef68e7..7c694503056ab 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out @@ -32,6 +32,7 @@ BUCKETS false BY false BYTE false CACHE false +CALL true CALLED false CASCADE false CASE true @@ -141,6 +142,7 @@ HAVING true HOUR false HOURS false IDENTIFIER false +IDENTITY false IF false IGNORE false ILIKE false @@ -148,6 +150,7 @@ IMMEDIATE false IMPORT false IN true INCLUDE false +INCREMENT false INDEX false INDEXES false INNER true @@ -163,6 +166,7 @@ INTO true INVOKER false IS true ITEMS false +ITERATE false JOIN true KEYS false LANGUAGE false @@ -170,6 +174,7 @@ LAST false LATERAL true LAZY false LEADING true +LEAVE false LEFT true LIKE false LIMIT false @@ -249,6 +254,7 @@ REFERENCES true REFRESH false RENAME false REPAIR false +REPEAT false REPEATABLE false REPLACE false RESET false @@ -334,6 +340,7 @@ UNKNOWN true UNLOCK false UNPIVOT false UNSET false +UNTIL false UPDATE false USE false USER true @@ -372,6 +379,7 @@ ANY AS AUTHORIZATION BOTH +CALL CASE CAST CHECK diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out index e2abcb099130a..fb60a920040e6 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out @@ -881,7 +881,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'invalid'", "sourceType" : "\"STRING\"", "targetType" : "\"DOUBLE\"" diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index 93077221ee5a8..cf1bce3c0e504 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -101,7 +101,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'a'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -142,7 +141,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : 
"22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'a'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -507,7 +505,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'invalid_length'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -532,7 +529,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'invalid_length'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out index 92da0a490ff81..c1330c620acfb 100644 --- a/sql/core/src/test/resources/sql-tests/results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -151,27 +151,9 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query select sort_array(array('b', 'd'), cast(NULL as boolean)) -- !query schema -struct<> +struct> -- !query output -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - "sqlState" : "42K09", - "messageParameters" : { - "inputSql" : "\"CAST(NULL AS BOOLEAN)\"", - "inputType" : "\"BOOLEAN\"", - "paramIndex" : "second", - "requiredType" : "\"BOOLEAN\"", - "sqlExpr" : "\"sort_array(array(b, d), CAST(NULL AS BOOLEAN))\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 57, - "fragment" : "sort_array(array('b', 'd'), cast(NULL as boolean))" - } ] -} +NULL -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out index 6f74c63da3543..738697c638832 100644 --- a/sql/core/src/test/resources/sql-tests/results/cast.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out @@ -642,7 +642,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"INTERVAL HOUR TO SECOND\"", "targetType" : "\"SMALLINT\"", "value" : "INTERVAL '23:59:59' HOUR TO SECOND" @@ -676,7 +675,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"INTERVAL MONTH\"", "targetType" : "\"TINYINT\"", "value" : "INTERVAL '-1000' MONTH" @@ -694,7 +692,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"INTERVAL SECOND\"", "targetType" : "\"SMALLINT\"", "value" : "INTERVAL '1000000' SECOND" @@ -784,7 +781,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"INT\"", "targetType" : "\"INTERVAL YEAR\"", "value" : "2147483647" @@ -802,7 +798,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"BIGINT\"", "targetType" : "\"INTERVAL DAY\"", "value" : 
"-9223372036854775808L" diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out index d8f8d0676baed..9d29a46e5a0ef 100644 --- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out @@ -480,6 +480,32 @@ struct +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(text, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"str_to_map(collate(text, unicode_ai), collate(pairDelim, unicode_ai), collate(keyValueDelim, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 106, + "fragment" : "str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai)" + } ] +} + + -- !query drop table t4 -- !query schema @@ -489,23 +515,3504 @@ struct<> -- !query -create table t5(str string collate utf8_binary, delimiter string collate utf8_lcase, partNum int) using parquet +create table t5(s string, utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase) using parquet +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t5 values ('Spark', 'Spark', 'SQL') +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t5 values ('aaAaAAaA', 'aaAaAAaA', 'aaAaAAaA') +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t5 values ('aaAaAAaA', 'aaAaAAaA', 'aaAaaAaA') +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t5 values ('aaAaAAaA', 'aaAaAAaA', 'aaAaaAaAaaAaaAaAaaAaaAaA') +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t5 values ('bbAbaAbA', 'bbAbAAbA', 'a') +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t5 values ('İo', 'İo', 'İo') +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t5 values ('İo', 'İo', 'i̇o') +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t5 values ('efd2', 'efd2', 'efd2') +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t5 values ('Hello, world! Nice day.', 'Hello, world! Nice day.', 'Hello, world! Nice day.') +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t5 values ('Something else. Nothing here.', 'Something else. Nothing here.', 'Something else. 
Nothing here.') +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t5 values ('kitten', 'kitten', 'sitTing') +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t5 values ('abc', 'abc', 'abc') +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t5 values ('abcdcba', 'abcdcba', 'aBcDCbA') +-- !query schema +struct<> +-- !query output + + + +-- !query +create table t6(ascii long) using parquet +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t6 values (97) +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t6 values (66) +-- !query schema +struct<> +-- !query output + + + +-- !query +create table t7(ascii double) using parquet +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t7 values (97.52143) +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t7 values (66.421) +-- !query schema +struct<> +-- !query output + + + +-- !query +create table t8(format string collate utf8_binary, utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase) using parquet +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t8 values ('%s%s', 'abCdE', 'abCdE') +-- !query schema +struct<> +-- !query output + + + +-- !query +create table t9(num long) using parquet +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t9 values (97) +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t9 values (66) +-- !query schema +struct<> +-- !query output + + + +-- !query +create table t10(utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase) using parquet +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t10 values ('aaAaAAaA', 'aaAaaAaA') +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t10 values ('efd2', 'efd2') +-- !query schema +struct<> +-- !query output + + + +-- !query +select concat_ws(' ', utf8_lcase, utf8_lcase) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +SQL SQL +Something else. Nothing here. Something else. Nothing here. +a a +aBcDCbA aBcDCbA +aaAaAAaA aaAaAAaA +aaAaaAaA aaAaaAaA +aaAaaAaAaaAaaAaAaaAaaAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +efd2 efd2 +i̇o i̇o +sitTing sitTing +İo İo + + +-- !query +select concat_ws(' ', utf8_binary, utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select concat_ws(' ' collate utf8_binary, utf8_binary, 'SQL' collate utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select concat_ws(' ' collate utf8_lcase, utf8_binary, 'SQL' collate utf8_lcase) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. SQL +Something else. Nothing here. SQL +Spark SQL +aaAaAAaA SQL +aaAaAAaA SQL +aaAaAAaA SQL +abc SQL +abcdcba SQL +bbAbAAbA SQL +efd2 SQL +kitten SQL +İo SQL +İo SQL + + +-- !query +select concat_ws(',', utf8_lcase, 'word'), concat_ws(',', utf8_binary, 'word') from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day.,word Hello, world! 
Nice day.,word +SQL,word Spark,word +Something else. Nothing here.,word Something else. Nothing here.,word +a,word bbAbAAbA,word +aBcDCbA,word abcdcba,word +aaAaAAaA,word aaAaAAaA,word +aaAaaAaA,word aaAaAAaA,word +aaAaaAaAaaAaaAaAaaAaaAaA,word aaAaAAaA,word +abc,word abc,word +efd2,word efd2,word +i̇o,word İo,word +sitTing,word kitten,word +İo,word İo,word + + +-- !query +select concat_ws(',', utf8_lcase, 'word' collate utf8_binary), concat_ws(',', utf8_binary, 'word' collate utf8_lcase) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day.,word Hello, world! Nice day.,word +SQL,word Spark,word +Something else. Nothing here.,word Something else. Nothing here.,word +a,word bbAbAAbA,word +aBcDCbA,word abcdcba,word +aaAaAAaA,word aaAaAAaA,word +aaAaaAaA,word aaAaAAaA,word +aaAaaAaAaaAaaAaAaaAaaAaA,word aaAaAAaA,word +abc,word abc,word +efd2,word efd2,word +i̇o,word İo,word +sitTing,word kitten,word +İo,word İo,word + + +-- !query +select elt(2, s, utf8_binary) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. +Something else. Nothing here. +Spark +aaAaAAaA +aaAaAAaA +aaAaAAaA +abc +abcdcba +bbAbAAbA +efd2 +kitten +İo +İo + + +-- !query +select elt(2, utf8_binary, utf8_lcase, s) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select elt(1, utf8_binary collate utf8_binary, utf8_lcase collate utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select elt(1, utf8_binary collate utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. +Something else. Nothing here. +Spark +aaAaAAaA +aaAaAAaA +aaAaAAaA +abc +abcdcba +bbAbAAbA +efd2 +kitten +İo +İo + + +-- !query +select elt(1, utf8_binary collate utf8_binary, utf8_lcase) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. +Something else. Nothing here. +Spark +aaAaAAaA +aaAaAAaA +aaAaAAaA +abc +abcdcba +bbAbAAbA +efd2 +kitten +İo +İo + + +-- !query +select elt(1, utf8_binary, 'word'), elt(1, utf8_lcase, 'word') from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select elt(1, utf8_binary, 'word' collate utf8_lcase), elt(1, utf8_lcase, 'word' collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. 
+Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select split_part(utf8_binary, utf8_lcase, 3) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select split_part(s, utf8_binary, 1) from t5 +-- !query schema +struct +-- !query output + + + + + + + + + + + + +bbAbaAbA + + +-- !query +select split_part(utf8_binary collate utf8_binary, s collate utf8_lcase, 1) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select split_part(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5 +-- !query schema +struct +-- !query output + + + +-- !query +select split_part(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5 +-- !query schema +struct +-- !query output + + + + + + + + + + + + +b + + +-- !query +select split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"split_part(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + +-- !query +select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5 +-- !query schema +struct +-- !query output + + + + + + + + + + +A +A +A + + +-- !query +select split_part(utf8_binary, 'a' collate utf8_lcase, 3), split_part(utf8_lcase, 'a' collate utf8_binary, 3) from t5 +-- !query schema +struct +-- !query output + + + + + + + + + + + A + A + A + + +-- !query +select contains(utf8_binary, utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select contains(s, utf8_binary) from t5 +-- !query schema +struct +-- !query output +false +true +true +true +true +true +true +true +true +true +true +true +true + + +-- !query +select contains(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select contains(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +false +false +false +false +false +false +false +true +true +true +true +true +true + + +-- !query +select contains(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query schema +struct +-- !query output +false +false +false +true +true +true +true +true 
+true +true +true +true +true + + +-- !query +select contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"contains(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query +select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5 +-- !query schema +struct +-- !query output +false false +false false +false false +false false +false false +false true +true false +true true +true true +true true +true true +true true +true true + + +-- !query +select contains(utf8_binary, 'AaAA' collate utf8_lcase), contains(utf8_lcase, 'AAa' collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +false false +false false +false false +false false +false false +false false +false false +false false +false false +false false +true false +true false +true true + + +-- !query +select substring_index(utf8_binary, utf8_lcase, 2) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select substring_index(s, utf8_binary,1) from t5 +-- !query schema +struct +-- !query output + + + + + + + + + + + + +bbAbaAbA + + +-- !query +select substring_index(utf8_binary collate utf8_binary, s collate utf8_lcase, 3) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select substring_index(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. +Something else. Nothing here. +Spark +aaAaAAaA +aaAaAAaA +aaAaAAaA +abc +abcdcba +bbAbAAbA +efd2 +kitten +İo +İo + + +-- !query +select substring_index(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. +Something else. Nothing here. 
+Spark +aaAaAAaA +aaAaAAaA +aaAaAAaA +abc +abcdcba +bbAb +efd2 +kitten +İo +İo + + +-- !query +select substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"substring_index(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 88, + "fragment" : "substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + +-- !query +select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +a a +a a +a a +abc abc +abcdcb aBcDCb +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select substring_index(utf8_binary, 'AaAA' collate utf8_lcase, 2), substring_index(utf8_lcase, 'AAa' collate utf8_binary, 2) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +a aaAaAAaA +a aaAaaAaA +a aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select instr(utf8_binary, utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select instr(s, utf8_binary) from t5 +-- !query schema +struct +-- !query output +0 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 + + +-- !query +select instr(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select instr(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +0 +0 +0 +0 +0 +0 +0 +1 +1 +1 +1 +1 +1 + + +-- !query +select instr(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query schema +struct +-- !query output +0 +0 +0 +1 +1 +1 +1 +1 +1 +1 +1 +1 +3 + + +-- !query +select instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"instr(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query 
+select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5 +-- !query schema +struct +-- !query output +0 0 +0 0 +0 0 +0 0 +0 0 +0 1 +1 1 +1 1 +1 1 +1 1 +1 1 +21 21 +3 0 + + +-- !query +select instr(utf8_binary, 'AaAA' collate utf8_lcase), instr(utf8_lcase, 'AAa' collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +1 0 +1 0 +1 5 + + +-- !query +select find_in_set(utf8_binary, utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select find_in_set(s, utf8_binary) from t5 +-- !query schema +struct +-- !query output +0 +0 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 + + +-- !query +select find_in_set(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select find_in_set(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +0 +0 +0 +0 +0 +0 +0 +0 +1 +1 +1 +1 +1 + + +-- !query +select find_in_set(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query schema +struct +-- !query output +0 +0 +0 +0 +0 +1 +1 +1 +1 +1 +1 +1 +1 + + +-- !query +select find_in_set(utf8_binary, 'aaAaaAaA,i̇o'), find_in_set(utf8_lcase, 'aaAaaAaA,i̇o') from t5 +-- !query schema +struct +-- !query output +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 1 +0 1 +0 2 +0 2 + + +-- !query +select find_in_set(utf8_binary, 'aaAaaAaA,i̇o' collate utf8_lcase), find_in_set(utf8_lcase, 'aaAaaAaA,i̇o' collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +1 0 +1 0 +1 1 +2 0 +2 2 + + +-- !query +select startswith(utf8_binary, utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select startswith(s, utf8_binary) from t5 +-- !query schema +struct +-- !query output +false +true +true +true +true +true +true +true +true +true +true +true +true + + +-- !query +select startswith(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select startswith(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +false +false +false +false +false +false +false +true +true +true +true +true +true + + +-- !query +select startswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query schema +struct +-- !query output +false +false +false +false +true +true +true +true +true +true +true +true +true + + +-- !query +select startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, 
unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"startswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 80, + "fragment" : "startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query +select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5 +-- !query schema +struct +-- !query output +false false +false false +false false +false false +false false +false false +false false +false false +false false +false false +false true +false true +false true + + +-- !query +select startswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), startswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +false false +false false +false false +false false +false false +false false +false false +false false +false false +false false +true false +true true +true true + + +-- !query +select translate(utf8_lcase, utf8_lcase, '12345') from t5 +-- !query schema +struct +-- !query output +1 +11111111 +11111111 +111111111111111111111111 +12 +123 +123 +123 +12332 +12335532 +1234 +1234321 +123454142544 + + +-- !query +select translate(utf8_binary, utf8_lcase, '12345') from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select translate(utf8_binary, 'aBc' collate utf8_lcase, '12345' collate utf8_binary) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string collate UTF8_LCASE`, `string`" + } +} + + +-- !query +select translate(utf8_binary, 'SQL' collate utf8_lcase, '12345' collate utf8_lcase) from t5 +-- !query schema +struct +-- !query output +1omething e31e. Nothing here. +1park +He33o, wor3d! Nice day. +aaAaAAaA +aaAaAAaA +aaAaAAaA +abc +abcdcba +bbAbAAbA +efd2 +kitten +İo +İo + + +-- !query +select translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"utf8_binary\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"translate(utf8_binary, collate(SQL, unicode_ai), collate(12345, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai)" + } ] +} + + +-- !query +select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5 +-- !query schema +struct +-- !query output +1 bb3b33b3 +11111111 11313313 +11111111 11313313 +111111111111111111111111 11313313 +1BcDCb1 1bcdcb1 +1bc 1bc +Hello, world! Nice d1y. Hello, world! Nice d1y. +SQL Sp1rk +Something else. Nothing here. Something else. Nothing here. 
+efd2 efd2 +i̇o İo +sitTing kitten +İo İo + + +-- !query +select translate(utf8_lcase, 'aBc' collate utf8_binary, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5 +-- !query schema +struct +-- !query output +1 22121121 +11A11A1A 11111111 +11A11A1A11A11A1A11A11A1A 11111111 +11A1AA1A 11111111 +123DCbA 123d321 +1b3 123 +Hello, world! Ni3e d1y. Hello, world! Ni3e d1y. +SQL Sp1rk +Something else. Nothing here. Something else. Nothing here. +efd2 efd2 +i̇o İo +sitTing kitten +İo İo + + +-- !query +select replace(utf8_binary, utf8_lcase, 'abc') from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select replace(s, utf8_binary, 'abc') from t5 +-- !query schema +struct +-- !query output +abc +abc +abc +abc +abc +abc +abc +abc +abc +abc +abc +abc +bbAbaAbA + + +-- !query +select replace(utf8_binary collate utf8_binary, s collate utf8_lcase, 'abc') from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select replace(utf8_binary, utf8_lcase collate utf8_binary, 'abc') from t5 +-- !query schema +struct +-- !query output +Spark +aaAaAAaA +aaAaAAaA +abc +abc +abc +abc +abc +abc +abcdcba +bbAbAAbA +kitten +İo + + +-- !query +select replace(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 'abc') from t5 +-- !query schema +struct +-- !query output +Spark +aaAaAAaA +abc +abc +abc +abc +abc +abc +abc +abc +abc +bbabcbabcabcbabc +kitten + + +-- !query +select replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc') from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), abc)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 84, + "fragment" : "replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc')" + } ] +} + + +-- !query +select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA abc +aaAaAAaA abc +aaAaAAaA abcabcabc +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select replace(utf8_binary, 'aaAaaAaA' collate utf8_lcase, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. 
+Spark SQL +abc aaAaAAaA +abc abc +abc abc +abc abcabcabc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select endswith(utf8_binary, utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select endswith(s, utf8_binary) from t5 +-- !query schema +struct +-- !query output +false +true +true +true +true +true +true +true +true +true +true +true +true + + +-- !query +select endswith(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select endswith(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +false +false +false +false +false +false +false +true +true +true +true +true +true + + +-- !query +select endswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query schema +struct +-- !query output +false +false +false +true +true +true +true +true +true +true +true +true +true + + +-- !query +select endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"endswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query +select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5 +-- !query schema +struct +-- !query output +false false +false false +false false +false false +false false +false false +false false +false false +false false +false false +false true +false true +false true + + +-- !query +select endswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), endswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +false false +false false +false false +false false +false false +false false +false false +false false +false false +false false +true false +true true +true true + + +-- !query +select repeat(utf8_binary, 3), repeat(utf8_lcase, 2) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day.Hello, world! Nice day.Hello, world! Nice day. Hello, world! Nice day.Hello, world! Nice day. +Something else. Nothing here.Something else. Nothing here.Something else. Nothing here. Something else. Nothing here.Something else. Nothing here. 
+SparkSparkSpark SQLSQL +aaAaAAaAaaAaAAaAaaAaAAaA aaAaAAaAaaAaAAaA +aaAaAAaAaaAaAAaAaaAaAAaA aaAaaAaAaaAaaAaA +aaAaAAaAaaAaAAaAaaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaAaaAaaAaAaaAaaAaAaaAaaAaA +abcabcabc abcabc +abcdcbaabcdcbaabcdcba aBcDCbAaBcDCbA +bbAbAAbAbbAbAAbAbbAbAAbA aa +efd2efd2efd2 efd2efd2 +kittenkittenkitten sitTingsitTing +İoİoİo i̇oi̇o +İoİoİo İoİo + + +-- !query +select repeat(utf8_binary collate utf8_lcase, 3), repeat(utf8_lcase collate utf8_binary, 2) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day.Hello, world! Nice day.Hello, world! Nice day. Hello, world! Nice day.Hello, world! Nice day. +Something else. Nothing here.Something else. Nothing here.Something else. Nothing here. Something else. Nothing here.Something else. Nothing here. +SparkSparkSpark SQLSQL +aaAaAAaAaaAaAAaAaaAaAAaA aaAaAAaAaaAaAAaA +aaAaAAaAaaAaAAaAaaAaAAaA aaAaaAaAaaAaaAaA +aaAaAAaAaaAaAAaAaaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaAaaAaaAaAaaAaaAaAaaAaaAaA +abcabcabc abcabc +abcdcbaabcdcbaabcdcba aBcDCbAaBcDCbA +bbAbAAbAbbAbAAbAbbAbAAbA aa +efd2efd2efd2 efd2efd2 +kittenkittenkitten sitTingsitTing +İoİoİo i̇oi̇o +İoİoİo İoİo + + +-- !query +select ascii(utf8_binary), ascii(utf8_lcase) from t5 +-- !query schema +struct +-- !query output +101 101 +107 115 +304 105 +304 304 +72 72 +83 83 +83 83 +97 97 +97 97 +97 97 +97 97 +97 97 +98 97 + + +-- !query +select ascii(utf8_binary collate utf8_lcase), ascii(utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +101 101 +107 115 +304 105 +304 304 +72 72 +83 83 +83 83 +97 97 +97 97 +97 97 +97 97 +97 97 +98 97 + + +-- !query +select unbase64(utf8_binary), unbase64(utf8_lcase) from t10 +-- !query schema +struct +-- !query output +i�� i�h� +y�v y�v + + +-- !query +select unbase64(utf8_binary collate utf8_lcase), unbase64(utf8_lcase collate utf8_binary) from t10 +-- !query schema +struct +-- !query output +i�� i�h� +y�v y�v + + +-- !query +select chr(ascii) from t6 +-- !query schema +struct +-- !query output +B +a + + +-- !query +select base64(utf8_binary), base64(utf8_lcase) from t5 +-- !query schema +struct +-- !query output +SGVsbG8sIHdvcmxkISBOaWNlIGRheS4= SGVsbG8sIHdvcmxkISBOaWNlIGRheS4= +U29tZXRoaW5nIGVsc2UuIE5vdGhpbmcgaGVyZS4= U29tZXRoaW5nIGVsc2UuIE5vdGhpbmcgaGVyZS4= +U3Bhcms= U1FM +YWFBYUFBYUE= YWFBYUFBYUE= +YWFBYUFBYUE= YWFBYWFBYUE= +YWFBYUFBYUE= YWFBYWFBYUFhYUFhYUFhQWFhQWFhQWFB +YWJj YWJj +YWJjZGNiYQ== YUJjRENiQQ== +YmJBYkFBYkE= YQ== +ZWZkMg== ZWZkMg== +a2l0dGVu c2l0VGluZw== +xLBv acyHbw== +xLBv xLBv + + +-- !query +select base64(utf8_binary collate utf8_lcase), base64(utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +SGVsbG8sIHdvcmxkISBOaWNlIGRheS4= SGVsbG8sIHdvcmxkISBOaWNlIGRheS4= +U29tZXRoaW5nIGVsc2UuIE5vdGhpbmcgaGVyZS4= U29tZXRoaW5nIGVsc2UuIE5vdGhpbmcgaGVyZS4= +U3Bhcms= U1FM +YWFBYUFBYUE= YWFBYUFBYUE= +YWFBYUFBYUE= YWFBYWFBYUE= +YWFBYUFBYUE= YWFBYWFBYUFhYUFhYUFhQWFhQWFhQWFB +YWJj YWJj +YWJjZGNiYQ== YUJjRENiQQ== +YmJBYkFBYkE= YQ== +ZWZkMg== ZWZkMg== +a2l0dGVu c2l0VGluZw== +xLBv acyHbw== +xLBv xLBv + + +-- !query +select decode(encode(utf8_binary, 'utf-8'), 'utf-8'), decode(encode(utf8_lcase, 'utf-8'), 'utf-8') from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. 
+Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select decode(encode(utf8_binary collate utf8_lcase, 'utf-8'), 'utf-8'), decode(encode(utf8_lcase collate utf8_binary, 'utf-8'), 'utf-8') from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select format_number(ascii, '###.###') from t7 +-- !query schema +struct +-- !query output +66.421 +97.521 + + +-- !query +select format_number(ascii, '###.###' collate utf8_lcase) from t7 +-- !query schema +struct +-- !query output +66.421 +97.521 + + +-- !query +select encode(utf8_binary, 'utf-8'), encode(utf8_lcase, 'utf-8') from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select encode(utf8_binary collate utf8_lcase, 'utf-8'), encode(utf8_lcase collate utf8_binary, 'utf-8') from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select to_binary(utf8_binary, 'utf-8'), to_binary(utf8_lcase, 'utf-8') from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select to_binary(utf8_binary collate utf8_lcase, 'utf-8'), to_binary(utf8_lcase collate utf8_binary, 'utf-8') from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. 
+Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select sentences(utf8_binary), sentences(utf8_lcase) from t5 +-- !query schema +struct>,sentences(utf8_lcase, , ):array>> +-- !query output +[["Hello","world"],["Nice","day"]] [["Hello","world"],["Nice","day"]] +[["Something","else"],["Nothing","here"]] [["Something","else"],["Nothing","here"]] +[["Spark"]] [["SQL"]] +[["aaAaAAaA"]] [["aaAaAAaA"]] +[["aaAaAAaA"]] [["aaAaaAaA"]] +[["aaAaAAaA"]] [["aaAaaAaAaaAaaAaAaaAaaAaA"]] +[["abc"]] [["abc"]] +[["abcdcba"]] [["aBcDCbA"]] +[["bbAbAAbA"]] [["a"]] +[["efd2"]] [["efd2"]] +[["kitten"]] [["sitTing"]] +[["İo"]] [["i̇o"]] +[["İo"]] [["İo"]] + + +-- !query +select sentences(utf8_binary collate utf8_lcase), sentences(utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct>,sentences(collate(utf8_lcase, utf8_binary), , ):array>> +-- !query output +[["Hello","world"],["Nice","day"]] [["Hello","world"],["Nice","day"]] +[["Something","else"],["Nothing","here"]] [["Something","else"],["Nothing","here"]] +[["Spark"]] [["SQL"]] +[["aaAaAAaA"]] [["aaAaAAaA"]] +[["aaAaAAaA"]] [["aaAaaAaA"]] +[["aaAaAAaA"]] [["aaAaaAaAaaAaaAaAaaAaaAaA"]] +[["abc"]] [["abc"]] +[["abcdcba"]] [["aBcDCbA"]] +[["bbAbAAbA"]] [["a"]] +[["efd2"]] [["efd2"]] +[["kitten"]] [["sitTing"]] +[["İo"]] [["i̇o"]] +[["İo"]] [["İo"]] + + +-- !query +select upper(utf8_binary), upper(utf8_lcase) from t5 +-- !query schema +struct +-- !query output +AAAAAAAA AAAAAAAA +AAAAAAAA AAAAAAAA +AAAAAAAA AAAAAAAAAAAAAAAAAAAAAAAA +ABC ABC +ABCDCBA ABCDCBA +BBABAABA A +EFD2 EFD2 +HELLO, WORLD! NICE DAY. HELLO, WORLD! NICE DAY. +KITTEN SITTING +SOMETHING ELSE. NOTHING HERE. SOMETHING ELSE. NOTHING HERE. +SPARK SQL +İO İO +İO İO + + +-- !query +select upper(utf8_binary collate utf8_lcase), upper(utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +AAAAAAAA AAAAAAAA +AAAAAAAA AAAAAAAA +AAAAAAAA AAAAAAAAAAAAAAAAAAAAAAAA +ABC ABC +ABCDCBA ABCDCBA +BBABAABA A +EFD2 EFD2 +HELLO, WORLD! NICE DAY. HELLO, WORLD! NICE DAY. +KITTEN SITTING +SOMETHING ELSE. NOTHING HERE. SOMETHING ELSE. NOTHING HERE. +SPARK SQL +İO İO +İO İO + + +-- !query +select lower(utf8_binary), lower(utf8_lcase) from t5 +-- !query schema +struct +-- !query output +aaaaaaaa aaaaaaaa +aaaaaaaa aaaaaaaa +aaaaaaaa aaaaaaaaaaaaaaaaaaaaaaaa +abc abc +abcdcba abcdcba +bbabaaba a +efd2 efd2 +hello, world! nice day. hello, world! nice day. +i̇o i̇o +i̇o i̇o +kitten sitting +something else. nothing here. something else. nothing here. +spark sql + + +-- !query +select lower(utf8_binary collate utf8_lcase), lower(utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +aaaaaaaa aaaaaaaa +aaaaaaaa aaaaaaaa +aaaaaaaa aaaaaaaaaaaaaaaaaaaaaaaa +abc abc +abcdcba abcdcba +bbabaaba a +efd2 efd2 +hello, world! nice day. hello, world! nice day. +i̇o i̇o +i̇o i̇o +kitten sitting +something else. nothing here. something else. nothing here. +spark sql + + +-- !query +select initcap(utf8_binary), initcap(utf8_lcase) from t5 +-- !query schema +struct +-- !query output +Aaaaaaaa Aaaaaaaa +Aaaaaaaa Aaaaaaaa +Aaaaaaaa Aaaaaaaaaaaaaaaaaaaaaaaa +Abc Abc +Abcdcba Abcdcba +Bbabaaba A +Efd2 Efd2 +Hello, World! Nice Day. Hello, World! Nice Day. +Kitten Sitting +Something Else. Nothing Here. Something Else. Nothing Here. 
+Spark Sql +İo İo +İo İo + + +-- !query +select initcap(utf8_binary collate utf8_lcase), initcap(utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +Aaaaaaaa Aaaaaaaa +Aaaaaaaa Aaaaaaaa +Aaaaaaaa Aaaaaaaaaaaaaaaaaaaaaaaa +Abc Abc +Abcdcba Abcdcba +Bbabaaba A +Efd2 Efd2 +Hello, World! Nice Day. Hello, World! Nice Day. +Kitten Sitting +Something Else. Nothing Here. Something Else. Nothing Here. +Spark Sql +İo İo +İo İo + + +-- !query +select overlay(utf8_binary, utf8_lcase, 2) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select overlay(s, utf8_binary,1) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. +Something else. Nothing here. +Spark +aaAaAAaA +aaAaAAaA +aaAaAAaA +abc +abcdcba +bbAbAAbA +efd2 +kitten +İo +İo + + +-- !query +select overlay(utf8_binary collate utf8_binary, s collate utf8_lcase, 3) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select overlay(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5 +-- !query schema +struct +-- !query output +HHello, world! Nice day. +SSQLk +SSomething else. Nothing here. +aaBcDCbA +aaaAaAAaA +aaaAaaAaA +aaaAaaAaAaaAaaAaAaaAaaAaA +aabc +baAbAAbA +eefd2 +ksitTing +İi̇o +İİo + + +-- !query +select overlay(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5 +-- !query schema +struct +-- !query output +HHello, world! Nice day. +SSQLk +SSomething else. Nothing here. +aaBcDCbA +aaaAaAAaA +aaaAaaAaA +aaaAaaAaAaaAaaAaAaaAaaAaA +aabc +baAbAAbA +eefd2 +ksitTing +İi̇o +İİo + + +-- !query +select overlay(utf8_binary, 'a', 2), overlay(utf8_lcase, 'a', 2) from t5 +-- !query schema +struct +-- !query output +Hallo, world! Nice day. Hallo, world! Nice day. +Saark SaL +Samething else. Nothing here. Samething else. Nothing here. +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +aac aac +aacdcba aacDCbA +baAbAAbA aa +ead2 ead2 +katten satTing +İa iao +İa İa + + +-- !query +select overlay(utf8_binary, 'AaAA' collate utf8_lcase, 2), overlay(utf8_lcase, 'AAa' collate utf8_binary, 2) from t5 +-- !query schema +struct +-- !query output +HAaAA, world! Nice day. HAAao, world! Nice day. +SAaAA SAAa +SAaAAhing else. Nothing here. SAAathing else. Nothing here. 
+aAaAA aAAa +aAaAAAaA aAAaAAaA +aAaAAAaA aAAaaAaA +aAaAAAaA aAAaaAaAaaAaaAaAaaAaaAaA +aAaAAba aAAaCbA +bAaAAAbA aAAa +eAaAA eAAa +kAaAAn sAAaing +İAaAA iAAa +İAaAA İAAa + + +-- !query +select format_string(format, utf8_binary, utf8_lcase) from t8 +-- !query schema +struct +-- !query output +abCdEabCdE + + +-- !query +select format_string(format collate utf8_lcase, utf8_lcase, utf8_binary collate utf8_lcase, 3), format_string(format, utf8_lcase collate utf8_binary, utf8_binary) from t8 +-- !query schema +struct +-- !query output +abCdEabCdE abCdEabCdE + + +-- !query +select format_string(format, utf8_binary, utf8_lcase) from t8 +-- !query schema +struct +-- !query output +abCdEabCdE + + +-- !query +select soundex(utf8_binary), soundex(utf8_lcase) from t5 +-- !query schema +struct +-- !query output +A000 A000 +A000 A000 +A000 A000 +A120 A120 +A123 A123 +B110 A000 +E130 E130 +H464 H464 +K350 S352 +S162 S400 +S535 S535 +İo I000 +İo İo + + +-- !query +select soundex(utf8_binary collate utf8_lcase), soundex(utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +A000 A000 +A000 A000 +A000 A000 +A120 A120 +A123 A123 +B110 A000 +E130 E130 +H464 H464 +K350 S352 +S162 S400 +S535 S535 +İo I000 +İo İo + + +-- !query +select length(utf8_binary), length(utf8_lcase) from t5 +-- !query schema +struct +-- !query output +2 2 +2 3 +23 23 +29 29 +3 3 +4 4 +5 3 +6 7 +7 7 +8 1 +8 24 +8 8 +8 8 + + +-- !query +select length(utf8_binary collate utf8_lcase), length(utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +2 2 +2 3 +23 23 +29 29 +3 3 +4 4 +5 3 +6 7 +7 7 +8 1 +8 24 +8 8 +8 8 + + +-- !query +select bit_length(utf8_binary), bit_length(utf8_lcase) from t5 +-- !query schema +struct +-- !query output +184 184 +232 232 +24 24 +24 24 +24 32 +32 32 +40 24 +48 56 +56 56 +64 192 +64 64 +64 64 +64 8 + + +-- !query +select bit_length(utf8_binary collate utf8_lcase), bit_length(utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +184 184 +232 232 +24 24 +24 24 +24 32 +32 32 +40 24 +48 56 +56 56 +64 192 +64 64 +64 64 +64 8 + + +-- !query +select octet_length(utf8_binary), octet_length(utf8_lcase) from t5 +-- !query schema +struct +-- !query output +23 23 +29 29 +3 3 +3 3 +3 4 +4 4 +5 3 +6 7 +7 7 +8 1 +8 24 +8 8 +8 8 + + +-- !query +select octet_length(utf8_binary collate utf8_lcase), octet_length(utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +23 23 +29 29 +3 3 +3 3 +3 4 +4 4 +5 3 +6 7 +7 7 +8 1 +8 24 +8 8 +8 8 + + +-- !query +select luhn_check(num) from t9 +-- !query schema +struct +-- !query output +false +false + + +-- !query +select levenshtein(utf8_binary, utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select levenshtein(s, utf8_binary) from t5 +-- !query schema +struct +-- !query output +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +1 + + +-- !query +select levenshtein(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select levenshtein(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +0 +0 +0 +0 +0 +0 +1 +16 +2 +4 +4 +4 
+8 + + +-- !query +select levenshtein(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query schema +struct +-- !query output +0 +0 +0 +0 +0 +0 +1 +16 +2 +4 +4 +4 +8 + + +-- !query +select levenshtein(utf8_binary, 'a'), levenshtein(utf8_lcase, 'a') from t5 +-- !query schema +struct +-- !query output +2 2 +2 2 +2 3 +22 22 +29 29 +4 3 +4 4 +6 6 +6 7 +7 23 +7 7 +7 7 +8 0 + + +-- !query +select levenshtein(utf8_binary, 'AaAA' collate utf8_lcase, 3), levenshtein(utf8_lcase, 'AAa' collate utf8_binary, 4) from t5 +-- !query schema +struct +-- !query output +-1 -1 +-1 -1 +-1 -1 +-1 -1 +-1 -1 +-1 -1 +-1 -1 +-1 2 +-1 3 +-1 3 +-1 3 +-1 4 +3 3 + + +-- !query +select is_valid_utf8(utf8_binary), is_valid_utf8(utf8_lcase) from t5 +-- !query schema +struct +-- !query output +true true +true true +true true +true true +true true +true true +true true +true true +true true +true true +true true +true true +true true + + +-- !query +select is_valid_utf8(utf8_binary collate utf8_lcase), is_valid_utf8(utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +true true +true true +true true +true true +true true +true true +true true +true true +true true +true true +true true +true true +true true + + +-- !query +select make_valid_utf8(utf8_binary), make_valid_utf8(utf8_lcase) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select make_valid_utf8(utf8_binary collate utf8_lcase), make_valid_utf8(utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select validate_utf8(utf8_binary), validate_utf8(utf8_lcase) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select validate_utf8(utf8_binary collate utf8_lcase), validate_utf8(utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select try_validate_utf8(utf8_binary), try_validate_utf8(utf8_lcase) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select try_validate_utf8(utf8_binary collate utf8_lcase), try_validate_utf8(utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +Hello, world! 
Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaAaaAaaAaAaaAaaAaA +abc abc +abcdcba aBcDCbA +bbAbAAbA a +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select substr(utf8_binary, 2, 2), substr(utf8_lcase, 2, 2) from t5 +-- !query schema +struct +-- !query output +aA aA +aA aA +aA aA +bA +bc Bc +bc bc +el el +fd fd +it it +o o +o ̇o +om om +pa QL + + +-- !query +select substr(utf8_binary collate utf8_lcase, 2, 2), substr(utf8_lcase collate utf8_binary, 2, 2) from t5 +-- !query schema +struct +-- !query output +aA aA +aA aA +aA aA +bA +bc Bc +bc bc +el el +fd fd +it it +o o +o ̇o +om om +pa QL + + +-- !query +select right(utf8_binary, 2), right(utf8_lcase, 2) from t5 +-- !query schema +struct +-- !query output +aA aA +aA aA +aA aA +bA a +ba bA +bc bc +d2 d2 +e. e. +en ng +rk QL +y. y. +İo İo +İo ̇o + + +-- !query +select right(utf8_binary collate utf8_lcase, 2), right(utf8_lcase collate utf8_binary, 2) from t5 +-- !query schema +struct +-- !query output +aA aA +aA aA +aA aA +bA a +ba bA +bc bc +d2 d2 +e. e. +en ng +rk QL +y. y. +İo İo +İo ̇o + + +-- !query +select left(utf8_binary, '2' collate utf8_lcase), left(utf8_lcase, 2) from t5 +-- !query schema +struct +-- !query output +He He +So So +Sp SQ +aa aa +aa aa +aa aa +ab aB +ab ab +bb a +ef ef +ki si +İo i̇ +İo İo + + +-- !query +select left(utf8_binary collate utf8_lcase, 2), left(utf8_lcase collate utf8_binary, 2) from t5 +-- !query schema +struct +-- !query output +He He +So So +Sp SQ +aa aa +aa aa +aa aa +ab aB +ab ab +bb a +ef ef +ki si +İo i̇ +İo İo + + +-- !query +select rpad(utf8_binary, 8, utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select rpad(s, 8, utf8_binary) from t5 +-- !query schema +struct +-- !query output +Hello, w +Somethin +SparkSpa +aaAaAAaA +aaAaAAaA +aaAaAAaA +abcabcab +abcdcbaa +bbAbaAbA +efd2efd2 +kittenki +İoİoİoİo +İoİoİoİo + + +-- !query +select rpad(utf8_binary collate utf8_binary, 8, s collate utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select rpad(utf8_binary, 8, utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +Hello, w +Somethin +SparkSQL +aaAaAAaA +aaAaAAaA +aaAaAAaA +abcabcab +abcdcbaa +bbAbAAbA +efd2efd2 +kittensi +İoi̇oi̇o +İoİoİoİo + + +-- !query +select rpad(utf8_binary collate utf8_lcase, 8, utf8_lcase collate utf8_lcase) from t5 +-- !query schema +struct +-- !query output +Hello, w +Somethin +SparkSQL +aaAaAAaA +aaAaAAaA +aaAaAAaA +abcabcab +abcdcbaa +bbAbAAbA +efd2efd2 +kittensi +İoi̇oi̇o +İoİoİoİo + + +-- !query +select rpad(utf8_binary, 8, 'a'), rpad(utf8_lcase, 8, 'a') from t5 +-- !query schema +struct +-- !query output +Hello, w Hello, w +Somethin Somethin +Sparkaaa SQLaaaaa +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaA +abcaaaaa abcaaaaa +abcdcbaa aBcDCbAa +bbAbAAbA aaaaaaaa +efd2aaaa efd2aaaa +kittenaa sitTinga +İoaaaaaa i̇oaaaaa +İoaaaaaa İoaaaaaa + + +-- !query +select rpad(utf8_binary, 8, 'AaAA' collate utf8_lcase), rpad(utf8_lcase, 8, 'AAa' collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +Hello, w Hello, w 
+Somethin Somethin +SparkAaA SQLAAaAA +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaA +abcAaAAA abcAAaAA +abcdcbaA aBcDCbAA +bbAbAAbA aAAaAAaA +efd2AaAA efd2AAaA +kittenAa sitTingA +İoAaAAAa i̇oAAaAA +İoAaAAAa İoAAaAAa + + +-- !query +select lpad(utf8_binary, 8, utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select lpad(s, 8, utf8_binary) from t5 +-- !query schema +struct +-- !query output +Hello, w +Somethin +SpaSpark +aaAaAAaA +aaAaAAaA +aaAaAAaA +aabcdcba +abcababc +bbAbaAbA +efd2efd2 +kikitten +İoİoİoİo +İoİoİoİo + + +-- !query +select lpad(utf8_binary collate utf8_binary, 8, s collate utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select lpad(utf8_binary, 8, utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +Hello, w +SQLSpark +Somethin +aaAaAAaA +aaAaAAaA +aaAaAAaA +aabcdcba +abcababc +bbAbAAbA +efd2efd2 +i̇oi̇oİo +sikitten +İoİoİoİo + + +-- !query +select lpad(utf8_binary collate utf8_lcase, 8, utf8_lcase collate utf8_lcase) from t5 +-- !query schema +struct +-- !query output +Hello, w +SQLSpark +Somethin +aaAaAAaA +aaAaAAaA +aaAaAAaA +aabcdcba +abcababc +bbAbAAbA +efd2efd2 +i̇oi̇oİo +sikitten +İoİoİoİo + + +-- !query +select lpad(utf8_binary, 8, 'a'), lpad(utf8_lcase, 8, 'a') from t5 +-- !query schema +struct +-- !query output +Hello, w Hello, w +Somethin Somethin +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaA +aaaSpark aaaaaSQL +aaaaaabc aaaaaabc +aaaaaaİo aaaaaaİo +aaaaaaİo aaaaai̇o +aaaaefd2 aaaaefd2 +aabcdcba aaBcDCbA +aakitten asitTing +bbAbAAbA aaaaaaaa + + +-- !query +select lpad(utf8_binary, 8, 'AaAA' collate utf8_lcase), lpad(utf8_lcase, 8, 'AAa' collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +AaAAAabc AAaAAabc +AaAAAaİo AAaAAaİo +AaAAAaİo AAaAAi̇o +AaAAefd2 AAaAefd2 +AaASpark AAaAASQL +Aabcdcba AaBcDCbA +Aakitten AsitTing +Hello, w Hello, w +Somethin Somethin +aaAaAAaA aaAaAAaA +aaAaAAaA aaAaaAaA +aaAaAAaA aaAaaAaA +bbAbAAbA AAaAAaAa + + +-- !query +select locate(utf8_binary, utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select locate(s, utf8_binary) from t5 +-- !query schema +struct +-- !query output +0 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 + + +-- !query +select locate(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + } +} + + +-- !query +select locate(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output +0 +0 +0 +0 +0 +0 +0 +1 +1 +1 +1 +1 +1 + + +-- !query +select locate(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 3) from t5 +-- !query schema +struct +-- !query output +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +3 + + +-- !query +select locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3) from t5 +-- !query schema +struct<> 
+-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"locate(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 3)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 79, + "fragment" : "locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3)" + } ] +} + + +-- !query +select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5 +-- !query schema +struct +-- !query output +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 1 + + +-- !query +select locate(utf8_binary, 'AaAA' collate utf8_lcase, 4), locate(utf8_lcase, 'AAa' collate utf8_binary, 4) from t5 +-- !query schema +struct +-- !query output +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 +0 0 + + +-- !query +select TRIM(utf8_binary, utf8_lcase) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.IMPLICIT", + "sqlState" : "42P21" +} + + +-- !query +select TRIM(s, utf8_binary) from t5 +-- !query schema +struct +-- !query output + + + +-- !query +select TRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 -- !query schema struct<> -- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "COLLATION_MISMATCH.EXPLICIT", + "sqlState" : "42P21", + "messageParameters" : { + "explicitTypes" : "`string collate UTF8_LCASE`, `string`" + } +} + + +-- !query +select TRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output + + + + + + + + +BcDCbA +QL +a +i̇ +sitTing + + +-- !query +select TRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query schema +struct +-- !query output + + + + + + + + +QL +sitTing + + -- !query -insert into t5 values('11AB12AB13', 'AB', 2) +select TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 -- !query schema struct<> -- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 74, + "fragment" : "TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query +select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAa +aaAaAAa +aaAaAAa +ab +abcdcba D +bbAbAAb +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + +-- !query +select TRIM('ABc' collate utf8_lcase, utf8_binary), TRIM('AAa' collate utf8_binary, utf8_lcase) from t5 +-- !query schema +struct +-- !query output + + + + + bc +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. 
+Spark SQL +d BcDCb +efd2 efd2 +kitten sitTing +İo i̇o +İo İo -- !query -select split_part(str, delimiter, partNum) from t5 +select BTRIM(utf8_binary, utf8_lcase) from t5 -- !query schema struct<> -- !query output @@ -517,7 +4024,27 @@ org.apache.spark.sql.AnalysisException -- !query -select split_part(str collate utf8_binary, delimiter collate utf8_lcase, partNum) from t5 +select BTRIM(s, utf8_binary) from t5 +-- !query schema +struct +-- !query output + + + + + + + + + + + + +a + + +-- !query +select BTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 -- !query schema struct<> -- !query output @@ -532,39 +4059,113 @@ org.apache.spark.sql.AnalysisException -- !query -select split_part(str collate utf8_binary, delimiter collate utf8_binary, partNum) from t5 +select BTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5 -- !query schema -struct +struct -- !query output -12 + + + + + + + + +bbAbAAbA +d +kitte +park +İ -- !query -drop table t5 +select BTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 -- !query schema -struct<> +struct -- !query output + + + + + + +bbAbAAb +kitte +park +İ + + -- !query -create table t6 (utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase, threshold int) using parquet +select BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 -- !query schema struct<> -- !query output - +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_lcase, unicode_ai) FROM collate(utf8_binary, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} -- !query -insert into t6 values('kitten', 'sitting', 2) +select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5 -- !query schema -struct<> +struct -- !query output +AB +AB +AB B +ABc ABc +ABc ABc +ABc ABc +ABc ABc +ABc ABc +ABc ABc +Bc Bc +Bc Bc +Bc Bc +Bc Bc + +-- !query +select BTRIM('ABc' collate utf8_lcase, utf8_binary), BTRIM('AAa' collate utf8_binary, utf8_lcase) from t5 +-- !query schema +struct +-- !query output + + AA +ABc AAa +ABc AAa +ABc AAa +ABc AAa +ABc AAa +B AA +Bc +Bc +Bc +Bc AAa +c AA -- !query -select levenshtein(utf8_binary, utf8_lcase) from t6 +select LTRIM(utf8_binary, utf8_lcase) from t5 -- !query schema struct<> -- !query output @@ -576,7 +4177,15 @@ org.apache.spark.sql.AnalysisException -- !query -select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_lcase) from t6 +select LTRIM(s, utf8_binary) from t5 +-- !query schema +struct +-- !query output + + + +-- !query +select LTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 -- !query schema struct<> -- !query output @@ -585,21 +4194,119 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "`string collate UTF8_LCASE`, `string`" } } -- !query -select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_binary) from t6 +select LTRIM(utf8_binary, utf8_lcase collate utf8_binary) 
from t5 -- !query schema -struct +struct -- !query output -3 + + + + + + + + +BcDCbA +QL +a +i̇o +sitTing + + +-- !query +select LTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query schema +struct +-- !query output + + + + + + + + + + + +QL +sitTing + + +-- !query +select LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(LEADING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query +select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAaA +aaAaAAaA +aaAaAAaA +abc +abcdcba DCbA +bbAbAAbA +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select LTRIM('ABc' collate utf8_lcase, utf8_binary), LTRIM('AAa' collate utf8_binary, utf8_lcase) from t5 +-- !query schema +struct +-- !query output + + + + + bc +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +dcba BcDCbA +efd2 efd2 +kitten sitTing +İo i̇o +İo İo -- !query -select levenshtein(utf8_binary, utf8_lcase, threshold) from t6 +select RTRIM(utf8_binary, utf8_lcase) from t5 -- !query schema struct<> -- !query output @@ -611,7 +4318,15 @@ org.apache.spark.sql.AnalysisException -- !query -select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_lcase, threshold) from t6 +select RTRIM(s, utf8_binary) from t5 +-- !query schema +struct +-- !query output + + + +-- !query +select RTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5 -- !query schema struct<> -- !query output @@ -620,17 +4335,123 @@ org.apache.spark.sql.AnalysisException "errorClass" : "COLLATION_MISMATCH.EXPLICIT", "sqlState" : "42P21", "messageParameters" : { - "explicitTypes" : "`string`, `string collate UTF8_LCASE`" + "explicitTypes" : "`string collate UTF8_LCASE`, `string`" } } -- !query -select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_binary, threshold) from t6 +select RTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5 +-- !query schema +struct +-- !query output + + + + + + + + +SQL +a +aBcDCbA +i̇ +sitTing + + +-- !query +select RTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5 +-- !query schema +struct +-- !query output + + + + + + + + + + + +SQL +sitTing + + +-- !query +select RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : 
"\"TRIM(TRAILING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + +-- !query +select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5 +-- !query schema +struct +-- !query output +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +aaAaAAa +aaAaAAa +aaAaAAa +ab +abcdcba aBcD +bbAbAAb +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +select RTRIM('ABc' collate utf8_lcase, utf8_binary), RTRIM('AAa' collate utf8_binary, utf8_lcase) from t5 +-- !query schema +struct +-- !query output + + + + + abc +Hello, world! Nice day. Hello, world! Nice day. +Something else. Nothing here. Something else. Nothing here. +Spark SQL +abcd aBcDCb +efd2 efd2 +kitten sitTing +İo i̇o +İo İo + + +-- !query +drop table t5 -- !query schema -struct +struct<> -- !query output --1 + -- !query @@ -639,3 +4460,35 @@ drop table t6 struct<> -- !query output + + +-- !query +drop table t7 +-- !query schema +struct<> +-- !query output + + + +-- !query +drop table t8 +-- !query schema +struct<> +-- !query output + + + +-- !query +drop table t9 +-- !query schema +struct<> +-- !query output + + + +-- !query +drop table t10 +-- !query schema +struct<> +-- !query output + diff --git a/sql/core/src/test/resources/sql-tests/results/execute-immediate.sql.out b/sql/core/src/test/resources/sql-tests/results/execute-immediate.sql.out index 9249d7eb3e517..21ea4436f4fa5 100644 --- a/sql/core/src/test/resources/sql-tests/results/execute-immediate.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/execute-immediate.sql.out @@ -392,7 +392,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'invalid_cast_error_expected'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -603,7 +602,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'name1'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index e2b9e11eb6331..5471dafaec8eb 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -1142,9 +1142,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR` when cast to interval day to hour: 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "20 15:40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR`", + "typeName" : "interval day to hour" }, "queryContext" : [ { "objectType" : "", @@ -1163,9 +1167,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE` when cast to interval day to minute: 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "20 15:40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE`", + "typeName" : "interval day to minute" }, "queryContext" : [ { "objectType" : "", @@ -1184,9 +1192,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE` when cast to interval hour to minute: 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "15:40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE`", + "typeName" : "interval hour to minute" }, "queryContext" : [ { "objectType" : "", @@ -1205,9 +1217,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 15:40.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "15:40.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`", + "typeName" : "interval hour to second" }, "queryContext" : [ { "objectType" : "", @@ -1226,9 +1242,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 15:40, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "15:40", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`", + "typeName" : "interval hour to second" }, "queryContext" : [ { "objectType" : "", @@ -1247,9 +1267,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND` when cast to interval minute to second: 20 40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "20 40:32.99899999", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND`", + "typeName" : "interval minute to second" }, "queryContext" : [ { "objectType" : "", @@ -1678,9 +1702,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.INTERVAL_PARSING", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Error parsing interval year-month string: integer overflow" + "input" : "178956970-8", + "interval" : "year-month" }, "queryContext" : [ { "objectType" : "", @@ -2161,9 +2187,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match year-month format of `[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH` when cast to interval year to month: -\t2-2\t" + "input" : "-\t2-2\t", + "intervalStr" : "year-month", + "supportedFormat" : "`[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH`", + "typeName" : "interval year to month" }, "queryContext" : [ { "objectType" : "", @@ -2190,9 +2220,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND` when cast to interval day to second: \n-\t10\t 12:34:46.789\t, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "\n-\t10\t 12:34:46.789\t", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND`", + "typeName" : "interval day to second" }, "queryContext" : [ { "objectType" : "", @@ -2679,7 +2713,7 @@ SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1.0D struct<> -- !query output java.lang.ArithmeticException -not in range +rounded value is out of range for input 2.147483648E9 and rounding mode HALF_UP -- !query @@ -2759,7 +2793,7 @@ SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1.0D struct<> -- !query output java.lang.ArithmeticException -not in range +rounded value is out of range for input 9.223372036854776E18 and rounding mode HALF_UP -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out index ced8d6398a66f..11bafb2cf63c9 100644 --- a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out @@ -1878,6 +1878,33 @@ struct 1 2 3 +-- !query +select 1 +from t1 as t_outer +left join + lateral( + select b1,b2 + from + ( + select + t2.c1 as b1, + 1 as b2 + from t2 + union + select t_outer.c1 as b1, + null as b2 + ) as t_inner + where (t_inner.b1 < t_outer.c2 or t_inner.b1 is null) + and t_inner.b1 = t_outer.c1 + order by t_inner.b1,t_inner.b2 desc limit 1 + ) as lateral_table +-- !query schema +struct<1:int> +-- !query output +1 +1 + + -- !query DROP VIEW t1 -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index ca48e851e717c..2c16d961b1313 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -32,6 +32,7 @@ BUCKETS false BY false BYTE false CACHE false +CALL false CALLED false CASCADE false CASE false @@ -141,6 +142,7 @@ HAVING false HOUR false HOURS false IDENTIFIER false +IDENTITY false IF false IGNORE false ILIKE false @@ -148,6 +150,7 @@ IMMEDIATE false IMPORT false IN false INCLUDE false +INCREMENT false INDEX false INDEXES false INNER false @@ -163,6 +166,7 @@ INTO false INVOKER false IS false ITEMS false +ITERATE false JOIN false KEYS false LANGUAGE false @@ -170,6 +174,7 @@ LAST false LATERAL false LAZY false LEADING false +LEAVE false LEFT false LIKE false LIMIT false @@ -249,6 +254,7 @@ REFERENCES false REFRESH false RENAME false REPAIR false +REPEAT false REPEATABLE false REPLACE false RESET false @@ -334,6 +340,7 @@ UNKNOWN false UNLOCK false UNPIVOT false UNSET false +UNTIL false UPDATE false USE false USER false diff --git a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out new file mode 100644 index 0000000000000..38436b0941034 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out @@ -0,0 +1,576 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +drop table if exists t +-- !query schema +struct<> +-- !query output + + + +-- !query +create table t(x int, y string) using csv +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t values (0, 'abc'), (1, 'def') +-- !query schema +struct<> +-- !query output + + + +-- !query +drop table if exists other +-- !query schema +struct<> +-- !query output + + + +-- !query +create table other(a int, b int) using json +-- 
!query schema +struct<> +-- !query output + + + +-- !query +insert into other values (1, 1), (1, 2), (2, 4) +-- !query schema +struct<> +-- !query output + + + +-- !query +drop table if exists st +-- !query schema +struct<> +-- !query output + + + +-- !query +create table st(x int, col struct) using parquet +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into st values (1, (2, 3)) +-- !query schema +struct<> +-- !query output + + + +-- !query +table t +|> select 1 as x +-- !query schema +struct +-- !query output +1 +1 + + +-- !query +table t +|> select x, y +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> select x, y +|> select x + length(y) as z +-- !query schema +struct +-- !query output +3 +4 + + +-- !query +values (0), (1) tab(col) +|> select col * 2 as result +-- !query schema +struct +-- !query output +0 +2 + + +-- !query +(select * from t union all select * from t) +|> select x + length(y) as result +-- !query schema +struct +-- !query output +3 +3 +4 +4 + + +-- !query +(table t + |> select x, y + |> select x) +union all +select x from t where x < 1 +-- !query schema +struct +-- !query output +0 +0 +1 + + +-- !query +(select col from st) +|> select col.i1 +-- !query schema +struct +-- !query output +2 + + +-- !query +table st +|> select st.col.i1 +-- !query schema +struct +-- !query output +2 + + +-- !query +table t +|> select (select a from other where x = a limit 1) as result +-- !query schema +struct +-- !query output +1 +NULL + + +-- !query +select (values (0) tab(col) |> select col) as result +-- !query schema +struct +-- !query output +0 + + +-- !query +table t +|> select (select any_value(a) from other where x = a limit 1) as result +-- !query schema +struct +-- !query output +1 +NULL + + +-- !query +table t +|> select x + length(x) as z, z + 1 as plus_one +-- !query schema +struct +-- !query output +1 2 +2 3 + + +-- !query +table t +|> select first_value(x) over (partition by y) as result +-- !query schema +struct +-- !query output +0 +1 + + +-- !query +select 1 x, 2 y, 3 z +|> select 1 + sum(x) over (), + avg(y) over (), + x, + avg(x+1) over (partition by y order by z) AS a2 +|> select a2 +-- !query schema +struct +-- !query output +2.0 + + +-- !query +table t +|> select x, count(*) over () +|> select x +-- !query schema +struct +-- !query output +0 +1 + + +-- !query +table t +|> select distinct x, y +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> select * +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> select * except (y) +-- !query schema +struct +-- !query output +0 +1 + + +-- !query +table t +|> select /*+ repartition(3) */ * +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> select /*+ repartition(3) */ distinct x +-- !query schema +struct +-- !query output +0 +1 + + +-- !query +table t +|> select /*+ repartition(3) */ all x +-- !query schema +struct +-- !query output +0 +1 + + +-- !query +table t +|> select sum(x) as result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "PIPE_OPERATOR_SELECT_CONTAINS_AGGREGATE_FUNCTION", + "sqlState" : "0A000", + "messageParameters" : { + "expr" : "sum(x#x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 19, + "stopIndex" : 24, + "fragment" : "sum(x)" + } ] +} + + +-- !query +table t +|> select y, length(y) + sum(x) as result +-- !query schema +struct<> +-- !query 
output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "PIPE_OPERATOR_SELECT_CONTAINS_AGGREGATE_FUNCTION", + "sqlState" : "0A000", + "messageParameters" : { + "expr" : "sum(x#x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 34, + "stopIndex" : 39, + "fragment" : "sum(x)" + } ] +} + + +-- !query +table t +|> where true +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> where x + length(y) < 4 +-- !query schema +struct +-- !query output +0 abc + + +-- !query +table t +|> where x + length(y) < 4 +|> where x + length(y) < 3 +-- !query schema +struct +-- !query output + + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where x = 1 +-- !query schema +struct +-- !query output +1 3 + + +-- !query +table t +|> where t.x = 1 +-- !query schema +struct +-- !query output +1 def + + +-- !query +table t +|> where spark_catalog.default.t.x = 1 +-- !query schema +struct +-- !query output +1 def + + +-- !query +(select col from st) +|> where col.i1 = 1 +-- !query schema +struct> +-- !query output + + + +-- !query +table st +|> where st.col.i1 = 2 +-- !query schema +struct> +-- !query output +1 {"i1":2,"i2":3} + + +-- !query +table t +|> where exists (select a from other where x = a limit 1) +-- !query schema +struct +-- !query output +1 def + + +-- !query +table t +|> where (select any_value(a) from other where x = a limit 1) = 1 +-- !query schema +struct +-- !query output +1 def + + +-- !query +table t +|> where sum(x) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"(sum(x) = 1)\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 27, + "fragment" : "table t\n|> where sum(x) = 1" + } ] +} + + +-- !query +table t +|> where y = 'abc' or length(y) + sum(x) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"((y = abc) OR ((length(y) + sum(x)) = 1))\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 52, + "fragment" : "table t\n|> where y = 'abc' or length(y) + sum(x) = 1" + } ] +} + + +-- !query +table t +|> where first_value(x) over (partition by y) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +select * from t where first_value(x) over (partition by y) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +table t +|> select x, length(y) as z +|> where x + length(y) < 4 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `z`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 57, + 
"stopIndex" : 57, + "fragment" : "y" + } ] +} + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where sum(length(y)) = 3 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `sum_len`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 77, + "stopIndex" : 77, + "fragment" : "y" + } ] +} + + +-- !query +drop table t +-- !query schema +struct<> +-- !query output + + + +-- !query +drop table other +-- !query schema +struct<> +-- !query output + + + +-- !query +drop table st +-- !query schema +struct<> +-- !query output + diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/boolean.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/boolean.sql.out index 12660768b95cb..052e7b4f25224 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/boolean.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/boolean.sql.out @@ -57,7 +57,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'test'", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" @@ -90,7 +89,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'foo'", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" @@ -131,7 +129,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'yeah'", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" @@ -172,7 +169,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'nay'", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" @@ -197,7 +193,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'on'", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" @@ -222,7 +217,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'off'", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" @@ -247,7 +241,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'of'", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" @@ -272,7 +265,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'o'", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" @@ -297,7 +289,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'on_'", "sourceType" : "\"STRING\"", 
"targetType" : "\"BOOLEAN\"" @@ -322,7 +313,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'off_'", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" @@ -355,7 +345,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'11'", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" @@ -388,7 +377,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'000'", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" @@ -413,7 +401,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "''", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" @@ -535,7 +522,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "' tru e '", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" @@ -560,7 +546,6 @@ org.apache.spark.SparkRuntimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "''", "sourceType" : "\"STRING\"", "targetType" : "\"BOOLEAN\"" diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out index 6b4b343d9ccae..1a15610b4dede 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out @@ -97,7 +97,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'N A N'", "sourceType" : "\"STRING\"", "targetType" : "\"FLOAT\"" @@ -122,7 +121,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'NaN x'", "sourceType" : "\"STRING\"", "targetType" : "\"FLOAT\"" @@ -147,7 +145,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "' INFINITY x'", "sourceType" : "\"STRING\"", "targetType" : "\"FLOAT\"" @@ -196,7 +193,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'nan'", "sourceType" : "\"STRING\"", "targetType" : "\"DECIMAL(10,0)\"" @@ -393,7 +389,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"FLOAT\"", "targetType" : "\"INT\"", "value" : "2.14748365E9" @@ -419,7 +414,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : 
"\"spark.sql.ansi.enabled\"", "sourceType" : "\"FLOAT\"", "targetType" : "\"INT\"", "value" : "-2.1474839E9" @@ -461,7 +455,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"FLOAT\"", "targetType" : "\"BIGINT\"", "value" : "-9.22338E18" diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out.java21 b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out.java21 index 6126411071bc1..3c2189c399639 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out.java21 +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float4.sql.out.java21 @@ -97,7 +97,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'N A N'", "sourceType" : "\"STRING\"", "targetType" : "\"FLOAT\"" @@ -122,7 +121,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'NaN x'", "sourceType" : "\"STRING\"", "targetType" : "\"FLOAT\"" @@ -147,7 +145,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "' INFINITY x'", "sourceType" : "\"STRING\"", "targetType" : "\"FLOAT\"" @@ -196,7 +193,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'nan'", "sourceType" : "\"STRING\"", "targetType" : "\"DECIMAL(10,0)\"" @@ -393,7 +389,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"FLOAT\"", "targetType" : "\"INT\"", "value" : "2.1474836E9" @@ -419,7 +414,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"FLOAT\"", "targetType" : "\"INT\"", "value" : "-2.147484E9" @@ -461,7 +455,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"FLOAT\"", "targetType" : "\"BIGINT\"", "value" : "-9.22338E18" diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out index e1b880f343709..b1a114bea30ee 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/float8.sql.out @@ -129,7 +129,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'N A N'", "sourceType" : "\"STRING\"", "targetType" : "\"DOUBLE\"" @@ -154,7 +153,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'NaN x'", "sourceType" : 
"\"STRING\"", "targetType" : "\"DOUBLE\"" @@ -179,7 +177,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "' INFINITY x'", "sourceType" : "\"STRING\"", "targetType" : "\"DOUBLE\"" @@ -228,7 +225,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'nan'", "sourceType" : "\"STRING\"", "targetType" : "\"DECIMAL(10,0)\"" @@ -898,7 +894,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"DOUBLE\"", "targetType" : "\"BIGINT\"", "value" : "-9.22337203685478E18D" diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out index f6e4bd8bd7e08..5e8abc273b125 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out @@ -737,7 +737,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"BIGINT\"", "targetType" : "\"INT\"", "value" : "4567890123456789L" @@ -763,7 +762,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"BIGINT\"", "targetType" : "\"SMALLINT\"", "value" : "4567890123456789L" @@ -809,7 +807,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"DOUBLE\"", "targetType" : "\"BIGINT\"", "value" : "9.223372036854776E20D" @@ -898,7 +895,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"BIGINT\"", "targetType" : "\"INT\"", "value" : "-9223372036854775808L" diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out.java21 b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out.java21 index ee3f8625da8a4..e7df03dc8cadd 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out.java21 +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/int8.sql.out.java21 @@ -737,7 +737,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"BIGINT\"", "targetType" : "\"INT\"", "value" : "4567890123456789L" @@ -763,7 +762,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"BIGINT\"", "targetType" : "\"SMALLINT\"", "value" : "4567890123456789L" @@ -809,7 +807,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"DOUBLE\"", "targetType" : "\"BIGINT\"", "value" : "9.223372036854776E20D" @@ -898,7 
+895,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"BIGINT\"", "targetType" : "\"INT\"", "value" : "-9223372036854775808L" diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/interval.sql.out index bff615e22af0b..3855d922361bc 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/interval.sql.out @@ -102,9 +102,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR` when cast to interval day to hour: 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR`", + "typeName" : "interval day to hour" }, "queryContext" : [ { "objectType" : "", @@ -123,9 +127,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR` when cast to interval day to hour: 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03:04", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR`", + "typeName" : "interval day to hour" }, "queryContext" : [ { "objectType" : "", @@ -152,9 +160,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE` when cast to interval day to minute: 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03:04", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE`", + "typeName" : "interval day to minute" }, "queryContext" : [ { "objectType" : "", @@ -173,9 +185,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND` when cast to interval day to second: 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "1 2:03", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND`", + "typeName" : "interval day to second" }, "queryContext" : [ { "objectType" : "", @@ -202,9 +218,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE` when cast to interval hour to minute: 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE`", + "typeName" : "interval hour to minute" }, "queryContext" : [ { "objectType" : "", @@ -223,9 +243,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE` when cast to interval hour to minute: 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03:04", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE`", + "typeName" : "interval hour to minute" }, "queryContext" : [ { "objectType" : "", @@ -244,9 +268,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`", + "typeName" : "interval hour to second" }, "queryContext" : [ { "objectType" : "", @@ -265,9 +293,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." 
+ "input" : "1 2:03:04", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`", + "typeName" : "interval hour to second" }, "queryContext" : [ { "objectType" : "", @@ -286,9 +318,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND` when cast to interval minute to second: 1 2:03, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND`", + "typeName" : "interval minute to second" }, "queryContext" : [ { "objectType" : "", @@ -307,9 +343,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0063", + "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE", + "sqlState" : "22006", "messageParameters" : { - "msg" : "Interval string does not match day-time format of `[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND` when cast to interval minute to second: 1 2:03:04, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0." + "input" : "1 2:03:04", + "intervalStr" : "day-time", + "supportedFormat" : "`[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND`", + "typeName" : "interval minute to second" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out index 37b8a3e8fd19c..0a940f5f3c74a 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out @@ -66,7 +66,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'four: 2'", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" @@ -91,7 +90,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'four: 2'", "sourceType" : "\"STRING\"", "targetType" : "\"BIGINT\"" diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out index 7c920bbd32b3c..94692a57300f9 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out @@ -700,7 +700,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'foo'", "sourceType" : "\"STRING\"", "targetType" : "\"DOUBLE\"" diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out index 6cf5e69758d2a..352c5f05cb06c 100644 --- 
a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part2.sql.out @@ -489,7 +489,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'NaN'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out index 16984de3ff257..0b4e5e078ee15 100644 --- a/sql/core/src/test/resources/sql-tests/results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out @@ -113,3 +113,472 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "fragment" : "rand('1')" } ] } + + +-- !query +SELECT uniform(0, 1, 0) AS result +-- !query schema +struct +-- !query output +0 + + +-- !query +SELECT uniform(0, 10, 0) AS result +-- !query schema +struct +-- !query output +7 + + +-- !query +SELECT uniform(0L, 10L, 0) AS result +-- !query schema +struct +-- !query output +7 + + +-- !query +SELECT uniform(0, 10L, 0) AS result +-- !query schema +struct +-- !query output +7 + + +-- !query +SELECT uniform(0, 10S, 0) AS result +-- !query schema +struct +-- !query output +7 + + +-- !query +SELECT uniform(10, 20, 0) AS result +-- !query schema +struct +-- !query output +17 + + +-- !query +SELECT uniform(10.0F, 20.0F, 0) AS result +-- !query schema +struct +-- !query output +17.604954 + + +-- !query +SELECT uniform(10.0D, 20.0D, CAST(3 / 7 AS LONG)) AS result +-- !query schema +struct +-- !query output +17.604953758285916 + + +-- !query +SELECT uniform(10, 20.0F, 0) AS result +-- !query schema +struct +-- !query output +17.604954 + + +-- !query +SELECT uniform(10, 20, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query schema +struct +-- !query output +15 +16 +17 + + +-- !query +SELECT uniform(10, 20.0F) IS NOT NULL AS result +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT uniform(NULL, 1, 0) AS result +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT uniform(0, NULL, 0) AS result +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT uniform(0, 1, NULL) AS result +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT uniform(10, 20, col) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"col\"", + "inputName" : "seed", + "inputType" : "integer or floating-point", + "sqlExpr" : "\"uniform(10, 20, col)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 27, + "fragment" : "uniform(10, 20, col)" + } ] +} + + +-- !query +SELECT uniform(col, 10, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"col\"", + "inputName" : "min", + "inputType" : "integer or floating-point", + "sqlExpr" : "\"uniform(col, 10, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 26, + "fragment" : "uniform(col, 10, 
0)" + } ] +} + + +-- !query +SELECT uniform(10) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "1", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[2, 3]", + "functionName" : "`uniform`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 18, + "fragment" : "uniform(10)" + } ] +} + + +-- !query +SELECT uniform(10, 20, 30, 40) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "4", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[2, 3]", + "functionName" : "`uniform`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 30, + "fragment" : "uniform(10, 20, 30, 40)" + } ] +} + + +-- !query +SELECT randstr(1, 0) AS result +-- !query schema +struct +-- !query output +c + + +-- !query +SELECT randstr(5, 0) AS result +-- !query schema +struct +-- !query output +ceV0P + + +-- !query +SELECT randstr(10, 0) AS result +-- !query schema +struct +-- !query output +ceV0PXaR2I + + +-- !query +SELECT randstr(10S, 0) AS result +-- !query schema +struct +-- !query output +ceV0PXaR2I + + +-- !query +SELECT randstr(10, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query schema +struct +-- !query output +ceV0PXaR2I +fYxVfArnv7 +iSIv0VT2XL + + +-- !query +SELECT randstr(10) IS NOT NULL AS result +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT randstr(10L, 0) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"10\"", + "inputType" : "\"BIGINT\"", + "paramIndex" : "first", + "requiredType" : "INT or SMALLINT", + "sqlExpr" : "\"randstr(10, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 22, + "fragment" : "randstr(10L, 0)" + } ] +} + + +-- !query +SELECT randstr(10.0F, 0) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"10.0\"", + "inputType" : "\"FLOAT\"", + "paramIndex" : "first", + "requiredType" : "INT or SMALLINT", + "sqlExpr" : "\"randstr(10.0, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "randstr(10.0F, 0)" + } ] +} + + +-- !query +SELECT randstr(10.0D, 0) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"10.0\"", + "inputType" : "\"DOUBLE\"", + "paramIndex" : "first", + "requiredType" : "INT or SMALLINT", + "sqlExpr" : "\"randstr(10.0, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "randstr(10.0D, 0)" + } ] +} + + +-- !query +SELECT randstr(NULL, 0) AS result +-- !query schema +struct<> 
+-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"NULL\"", + "inputType" : "\"VOID\"", + "paramIndex" : "first", + "requiredType" : "INT or SMALLINT", + "sqlExpr" : "\"randstr(NULL, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 23, + "fragment" : "randstr(NULL, 0)" + } ] +} + + +-- !query +SELECT randstr(0, NULL) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"NULL\"", + "inputType" : "\"VOID\"", + "paramIndex" : "second", + "requiredType" : "INT or SMALLINT", + "sqlExpr" : "\"randstr(0, NULL)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 23, + "fragment" : "randstr(0, NULL)" + } ] +} + + +-- !query +SELECT randstr(col, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"col\"", + "inputName" : "length", + "inputType" : "INT or SMALLINT", + "sqlExpr" : "\"randstr(col, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 22, + "fragment" : "randstr(col, 0)" + } ] +} + + +-- !query +SELECT randstr(10, col) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"col\"", + "inputName" : "seedExpression", + "inputType" : "INT or SMALLINT", + "sqlExpr" : "\"randstr(10, col)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 23, + "fragment" : "randstr(10, col)" + } ] +} + + +-- !query +SELECT randstr(10, 0, 1) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "3", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[1, 2]", + "functionName" : "`randstr`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "randstr(10, 0, 1)" + } ] +} diff --git a/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out index 73d3ec737085f..249a03fdfbf87 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out @@ -943,7 +943,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'hello'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" @@ -990,7 +989,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - 
"ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"INT\"", "targetType" : "\"SMALLINT\"", "value" : "100000" @@ -1104,7 +1102,6 @@ org.apache.spark.SparkArithmeticException "errorClass" : "CAST_OVERFLOW", "sqlState" : "22003", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "sourceType" : "\"DOUBLE\"", "targetType" : "\"INT\"", "value" : "1.0E10D" @@ -1171,7 +1168,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'hello'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" diff --git a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out index 25aaadfc8e783..cd94674d2bf2b 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out @@ -384,7 +384,6 @@ org.apache.spark.SparkDateTimeException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'1'", "sourceType" : "\"STRING\"", "targetType" : "\"TIMESTAMP_NTZ\"" diff --git a/sql/core/src/test/resources/sql-tests/results/view-schema-binding-config.sql.out b/sql/core/src/test/resources/sql-tests/results/view-schema-binding-config.sql.out index b0d497e070477..4288457d56b40 100644 --- a/sql/core/src/test/resources/sql-tests/results/view-schema-binding-config.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/view-schema-binding-config.sql.out @@ -701,7 +701,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'a'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" diff --git a/sql/core/src/test/resources/sql-tests/results/view-schema-compensation.sql.out b/sql/core/src/test/resources/sql-tests/results/view-schema-compensation.sql.out index ffd1fbec47bbb..641365309d51c 100644 --- a/sql/core/src/test/resources/sql-tests/results/view-schema-compensation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/view-schema-compensation.sql.out @@ -187,7 +187,6 @@ org.apache.spark.SparkNumberFormatException "errorClass" : "CAST_INVALID_INPUT", "sqlState" : "22018", "messageParameters" : { - "ansiConfig" : "\"spark.sql.ansi.enabled\"", "expression" : "'a'", "sourceType" : "\"STRING\"", "targetType" : "\"INT\"" diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/explain.txt index 4bf7de791b279..96bed479d2e06 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/explain.txt @@ -175,125 +175,125 @@ Input [6]: [i_product_name#12, i_brand#9, i_class#10, i_category#11, sum#21, cou Keys [4]: [i_product_name#12, i_brand#9, i_class#10, i_category#11] Functions [1]: [avg(qoh#18)] Aggregate Attributes [1]: [avg(qoh#18)#23] -Results [5]: [i_product_name#12 AS i_product_name#24, i_brand#9 AS i_brand#25, i_class#10 AS i_class#26, i_category#11 AS i_category#27, avg(qoh#18)#23 AS qoh#28] +Results [5]: [i_product_name#12, i_brand#9, i_class#10, 
i_category#11, avg(qoh#18)#23 AS qoh#24] (27) ReusedExchange [Reuses operator id: 23] -Output [6]: [i_product_name#29, i_brand#30, i_class#31, i_category#32, sum#33, count#34] +Output [6]: [i_product_name#25, i_brand#26, i_class#27, i_category#28, sum#29, count#30] (28) HashAggregate [codegen id : 16] -Input [6]: [i_product_name#29, i_brand#30, i_class#31, i_category#32, sum#33, count#34] -Keys [4]: [i_product_name#29, i_brand#30, i_class#31, i_category#32] -Functions [1]: [avg(inv_quantity_on_hand#35)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#35)#17] -Results [4]: [i_product_name#29, i_brand#30, i_class#31, avg(inv_quantity_on_hand#35)#17 AS qoh#36] +Input [6]: [i_product_name#25, i_brand#26, i_class#27, i_category#28, sum#29, count#30] +Keys [4]: [i_product_name#25, i_brand#26, i_class#27, i_category#28] +Functions [1]: [avg(inv_quantity_on_hand#31)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#31)#17] +Results [4]: [i_product_name#25, i_brand#26, i_class#27, avg(inv_quantity_on_hand#31)#17 AS qoh#32] (29) HashAggregate [codegen id : 16] -Input [4]: [i_product_name#29, i_brand#30, i_class#31, qoh#36] -Keys [3]: [i_product_name#29, i_brand#30, i_class#31] -Functions [1]: [partial_avg(qoh#36)] -Aggregate Attributes [2]: [sum#37, count#38] -Results [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] +Input [4]: [i_product_name#25, i_brand#26, i_class#27, qoh#32] +Keys [3]: [i_product_name#25, i_brand#26, i_class#27] +Functions [1]: [partial_avg(qoh#32)] +Aggregate Attributes [2]: [sum#33, count#34] +Results [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] (30) Exchange -Input [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] -Arguments: hashpartitioning(i_product_name#29, i_brand#30, i_class#31, 5), ENSURE_REQUIREMENTS, [plan_id=5] +Input [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] +Arguments: hashpartitioning(i_product_name#25, i_brand#26, i_class#27, 5), ENSURE_REQUIREMENTS, [plan_id=5] (31) HashAggregate [codegen id : 17] -Input [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] -Keys [3]: [i_product_name#29, i_brand#30, i_class#31] -Functions [1]: [avg(qoh#36)] -Aggregate Attributes [1]: [avg(qoh#36)#41] -Results [5]: [i_product_name#29, i_brand#30, i_class#31, null AS i_category#42, avg(qoh#36)#41 AS qoh#43] +Input [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] +Keys [3]: [i_product_name#25, i_brand#26, i_class#27] +Functions [1]: [avg(qoh#32)] +Aggregate Attributes [1]: [avg(qoh#32)#37] +Results [5]: [i_product_name#25, i_brand#26, i_class#27, null AS i_category#38, avg(qoh#32)#37 AS qoh#39] (32) ReusedExchange [Reuses operator id: 23] -Output [6]: [i_product_name#44, i_brand#45, i_class#46, i_category#47, sum#48, count#49] +Output [6]: [i_product_name#40, i_brand#41, i_class#42, i_category#43, sum#44, count#45] (33) HashAggregate [codegen id : 25] -Input [6]: [i_product_name#44, i_brand#45, i_class#46, i_category#47, sum#48, count#49] -Keys [4]: [i_product_name#44, i_brand#45, i_class#46, i_category#47] -Functions [1]: [avg(inv_quantity_on_hand#50)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#50)#17] -Results [3]: [i_product_name#44, i_brand#45, avg(inv_quantity_on_hand#50)#17 AS qoh#51] +Input [6]: [i_product_name#40, i_brand#41, i_class#42, i_category#43, sum#44, count#45] +Keys [4]: [i_product_name#40, i_brand#41, i_class#42, i_category#43] +Functions [1]: [avg(inv_quantity_on_hand#46)] +Aggregate Attributes [1]: 
[avg(inv_quantity_on_hand#46)#17] +Results [3]: [i_product_name#40, i_brand#41, avg(inv_quantity_on_hand#46)#17 AS qoh#47] (34) HashAggregate [codegen id : 25] -Input [3]: [i_product_name#44, i_brand#45, qoh#51] -Keys [2]: [i_product_name#44, i_brand#45] -Functions [1]: [partial_avg(qoh#51)] -Aggregate Attributes [2]: [sum#52, count#53] -Results [4]: [i_product_name#44, i_brand#45, sum#54, count#55] +Input [3]: [i_product_name#40, i_brand#41, qoh#47] +Keys [2]: [i_product_name#40, i_brand#41] +Functions [1]: [partial_avg(qoh#47)] +Aggregate Attributes [2]: [sum#48, count#49] +Results [4]: [i_product_name#40, i_brand#41, sum#50, count#51] (35) Exchange -Input [4]: [i_product_name#44, i_brand#45, sum#54, count#55] -Arguments: hashpartitioning(i_product_name#44, i_brand#45, 5), ENSURE_REQUIREMENTS, [plan_id=6] +Input [4]: [i_product_name#40, i_brand#41, sum#50, count#51] +Arguments: hashpartitioning(i_product_name#40, i_brand#41, 5), ENSURE_REQUIREMENTS, [plan_id=6] (36) HashAggregate [codegen id : 26] -Input [4]: [i_product_name#44, i_brand#45, sum#54, count#55] -Keys [2]: [i_product_name#44, i_brand#45] -Functions [1]: [avg(qoh#51)] -Aggregate Attributes [1]: [avg(qoh#51)#56] -Results [5]: [i_product_name#44, i_brand#45, null AS i_class#57, null AS i_category#58, avg(qoh#51)#56 AS qoh#59] +Input [4]: [i_product_name#40, i_brand#41, sum#50, count#51] +Keys [2]: [i_product_name#40, i_brand#41] +Functions [1]: [avg(qoh#47)] +Aggregate Attributes [1]: [avg(qoh#47)#52] +Results [5]: [i_product_name#40, i_brand#41, null AS i_class#53, null AS i_category#54, avg(qoh#47)#52 AS qoh#55] (37) ReusedExchange [Reuses operator id: 23] -Output [6]: [i_product_name#60, i_brand#61, i_class#62, i_category#63, sum#64, count#65] +Output [6]: [i_product_name#56, i_brand#57, i_class#58, i_category#59, sum#60, count#61] (38) HashAggregate [codegen id : 34] -Input [6]: [i_product_name#60, i_brand#61, i_class#62, i_category#63, sum#64, count#65] -Keys [4]: [i_product_name#60, i_brand#61, i_class#62, i_category#63] -Functions [1]: [avg(inv_quantity_on_hand#66)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#66)#17] -Results [2]: [i_product_name#60, avg(inv_quantity_on_hand#66)#17 AS qoh#67] +Input [6]: [i_product_name#56, i_brand#57, i_class#58, i_category#59, sum#60, count#61] +Keys [4]: [i_product_name#56, i_brand#57, i_class#58, i_category#59] +Functions [1]: [avg(inv_quantity_on_hand#62)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#62)#17] +Results [2]: [i_product_name#56, avg(inv_quantity_on_hand#62)#17 AS qoh#63] (39) HashAggregate [codegen id : 34] -Input [2]: [i_product_name#60, qoh#67] -Keys [1]: [i_product_name#60] -Functions [1]: [partial_avg(qoh#67)] -Aggregate Attributes [2]: [sum#68, count#69] -Results [3]: [i_product_name#60, sum#70, count#71] +Input [2]: [i_product_name#56, qoh#63] +Keys [1]: [i_product_name#56] +Functions [1]: [partial_avg(qoh#63)] +Aggregate Attributes [2]: [sum#64, count#65] +Results [3]: [i_product_name#56, sum#66, count#67] (40) Exchange -Input [3]: [i_product_name#60, sum#70, count#71] -Arguments: hashpartitioning(i_product_name#60, 5), ENSURE_REQUIREMENTS, [plan_id=7] +Input [3]: [i_product_name#56, sum#66, count#67] +Arguments: hashpartitioning(i_product_name#56, 5), ENSURE_REQUIREMENTS, [plan_id=7] (41) HashAggregate [codegen id : 35] -Input [3]: [i_product_name#60, sum#70, count#71] -Keys [1]: [i_product_name#60] -Functions [1]: [avg(qoh#67)] -Aggregate Attributes [1]: [avg(qoh#67)#72] -Results [5]: [i_product_name#60, null AS i_brand#73, null AS i_class#74, 
null AS i_category#75, avg(qoh#67)#72 AS qoh#76] +Input [3]: [i_product_name#56, sum#66, count#67] +Keys [1]: [i_product_name#56] +Functions [1]: [avg(qoh#63)] +Aggregate Attributes [1]: [avg(qoh#63)#68] +Results [5]: [i_product_name#56, null AS i_brand#69, null AS i_class#70, null AS i_category#71, avg(qoh#63)#68 AS qoh#72] (42) ReusedExchange [Reuses operator id: 23] -Output [6]: [i_product_name#77, i_brand#78, i_class#79, i_category#80, sum#81, count#82] +Output [6]: [i_product_name#73, i_brand#74, i_class#75, i_category#76, sum#77, count#78] (43) HashAggregate [codegen id : 43] -Input [6]: [i_product_name#77, i_brand#78, i_class#79, i_category#80, sum#81, count#82] -Keys [4]: [i_product_name#77, i_brand#78, i_class#79, i_category#80] -Functions [1]: [avg(inv_quantity_on_hand#83)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#83)#17] -Results [1]: [avg(inv_quantity_on_hand#83)#17 AS qoh#84] +Input [6]: [i_product_name#73, i_brand#74, i_class#75, i_category#76, sum#77, count#78] +Keys [4]: [i_product_name#73, i_brand#74, i_class#75, i_category#76] +Functions [1]: [avg(inv_quantity_on_hand#79)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#79)#17] +Results [1]: [avg(inv_quantity_on_hand#79)#17 AS qoh#80] (44) HashAggregate [codegen id : 43] -Input [1]: [qoh#84] +Input [1]: [qoh#80] Keys: [] -Functions [1]: [partial_avg(qoh#84)] -Aggregate Attributes [2]: [sum#85, count#86] -Results [2]: [sum#87, count#88] +Functions [1]: [partial_avg(qoh#80)] +Aggregate Attributes [2]: [sum#81, count#82] +Results [2]: [sum#83, count#84] (45) Exchange -Input [2]: [sum#87, count#88] +Input [2]: [sum#83, count#84] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=8] (46) HashAggregate [codegen id : 44] -Input [2]: [sum#87, count#88] +Input [2]: [sum#83, count#84] Keys: [] -Functions [1]: [avg(qoh#84)] -Aggregate Attributes [1]: [avg(qoh#84)#89] -Results [5]: [null AS i_product_name#90, null AS i_brand#91, null AS i_class#92, null AS i_category#93, avg(qoh#84)#89 AS qoh#94] +Functions [1]: [avg(qoh#80)] +Aggregate Attributes [1]: [avg(qoh#80)#85] +Results [5]: [null AS i_product_name#86, null AS i_brand#87, null AS i_class#88, null AS i_category#89, avg(qoh#80)#85 AS qoh#90] (47) Union (48) TakeOrderedAndProject -Input [5]: [i_product_name#24, i_brand#25, i_class#26, i_category#27, qoh#28] -Arguments: 100, [qoh#28 ASC NULLS FIRST, i_product_name#24 ASC NULLS FIRST, i_brand#25 ASC NULLS FIRST, i_class#26 ASC NULLS FIRST, i_category#27 ASC NULLS FIRST], [i_product_name#24, i_brand#25, i_class#26, i_category#27, qoh#28] +Input [5]: [i_product_name#12, i_brand#9, i_class#10, i_category#11, qoh#24] +Arguments: 100, [qoh#24 ASC NULLS FIRST, i_product_name#12 ASC NULLS FIRST, i_brand#9 ASC NULLS FIRST, i_class#10 ASC NULLS FIRST, i_category#11 ASC NULLS FIRST], [i_product_name#12, i_brand#9, i_class#10, i_category#11, qoh#24] ===== Subqueries ===== @@ -306,22 +306,22 @@ BroadcastExchange (53) (49) Scan parquet spark_catalog.default.date_dim -Output [2]: [d_date_sk#7, d_month_seq#95] +Output [2]: [d_date_sk#7, d_month_seq#91] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1212), LessThanOrEqual(d_month_seq,1223), IsNotNull(d_date_sk)] ReadSchema: struct (50) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#7, d_month_seq#95] +Input [2]: [d_date_sk#7, d_month_seq#91] (51) Filter [codegen id : 1] -Input [2]: [d_date_sk#7, d_month_seq#95] -Condition : (((isnotnull(d_month_seq#95) AND 
(d_month_seq#95 >= 1212)) AND (d_month_seq#95 <= 1223)) AND isnotnull(d_date_sk#7)) +Input [2]: [d_date_sk#7, d_month_seq#91] +Condition : (((isnotnull(d_month_seq#91) AND (d_month_seq#91 >= 1212)) AND (d_month_seq#91 <= 1223)) AND isnotnull(d_date_sk#7)) (52) Project [codegen id : 1] Output [1]: [d_date_sk#7] -Input [2]: [d_date_sk#7, d_month_seq#95] +Input [2]: [d_date_sk#7, d_month_seq#91] (53) BroadcastExchange Input [1]: [d_date_sk#7] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/simplified.txt index 042f946b8fca4..0c4267b3ca513 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/simplified.txt @@ -1,7 +1,7 @@ TakeOrderedAndProject [qoh,i_product_name,i_brand,i_class,i_category] Union WholeStageCodegen (8) - HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(qoh),i_product_name,i_brand,i_class,i_category,qoh,sum,count] + HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(qoh),qoh,sum,count] HashAggregate [i_product_name,i_brand,i_class,i_category,qoh] [sum,count,sum,count] HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(inv_quantity_on_hand),qoh,sum,count] InputAdapter diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/explain.txt index 8aab8e91acfc8..4b8993f370f4d 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/explain.txt @@ -160,125 +160,125 @@ Input [6]: [i_product_name#11, i_brand#8, i_class#9, i_category#10, sum#21, coun Keys [4]: [i_product_name#11, i_brand#8, i_class#9, i_category#10] Functions [1]: [avg(qoh#18)] Aggregate Attributes [1]: [avg(qoh#18)#23] -Results [5]: [i_product_name#11 AS i_product_name#24, i_brand#8 AS i_brand#25, i_class#9 AS i_class#26, i_category#10 AS i_category#27, avg(qoh#18)#23 AS qoh#28] +Results [5]: [i_product_name#11, i_brand#8, i_class#9, i_category#10, avg(qoh#18)#23 AS qoh#24] (24) ReusedExchange [Reuses operator id: 20] -Output [6]: [i_product_name#29, i_brand#30, i_class#31, i_category#32, sum#33, count#34] +Output [6]: [i_product_name#25, i_brand#26, i_class#27, i_category#28, sum#29, count#30] (25) HashAggregate [codegen id : 10] -Input [6]: [i_product_name#29, i_brand#30, i_class#31, i_category#32, sum#33, count#34] -Keys [4]: [i_product_name#29, i_brand#30, i_class#31, i_category#32] -Functions [1]: [avg(inv_quantity_on_hand#35)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#35)#17] -Results [4]: [i_product_name#29, i_brand#30, i_class#31, avg(inv_quantity_on_hand#35)#17 AS qoh#36] +Input [6]: [i_product_name#25, i_brand#26, i_class#27, i_category#28, sum#29, count#30] +Keys [4]: [i_product_name#25, i_brand#26, i_class#27, i_category#28] +Functions [1]: [avg(inv_quantity_on_hand#31)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#31)#17] +Results [4]: [i_product_name#25, i_brand#26, i_class#27, avg(inv_quantity_on_hand#31)#17 AS qoh#32] (26) HashAggregate [codegen id : 10] -Input [4]: [i_product_name#29, i_brand#30, i_class#31, qoh#36] -Keys [3]: [i_product_name#29, i_brand#30, i_class#31] -Functions [1]: [partial_avg(qoh#36)] 
-Aggregate Attributes [2]: [sum#37, count#38] -Results [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] +Input [4]: [i_product_name#25, i_brand#26, i_class#27, qoh#32] +Keys [3]: [i_product_name#25, i_brand#26, i_class#27] +Functions [1]: [partial_avg(qoh#32)] +Aggregate Attributes [2]: [sum#33, count#34] +Results [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] (27) Exchange -Input [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] -Arguments: hashpartitioning(i_product_name#29, i_brand#30, i_class#31, 5), ENSURE_REQUIREMENTS, [plan_id=4] +Input [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] +Arguments: hashpartitioning(i_product_name#25, i_brand#26, i_class#27, 5), ENSURE_REQUIREMENTS, [plan_id=4] (28) HashAggregate [codegen id : 11] -Input [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] -Keys [3]: [i_product_name#29, i_brand#30, i_class#31] -Functions [1]: [avg(qoh#36)] -Aggregate Attributes [1]: [avg(qoh#36)#41] -Results [5]: [i_product_name#29, i_brand#30, i_class#31, null AS i_category#42, avg(qoh#36)#41 AS qoh#43] +Input [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] +Keys [3]: [i_product_name#25, i_brand#26, i_class#27] +Functions [1]: [avg(qoh#32)] +Aggregate Attributes [1]: [avg(qoh#32)#37] +Results [5]: [i_product_name#25, i_brand#26, i_class#27, null AS i_category#38, avg(qoh#32)#37 AS qoh#39] (29) ReusedExchange [Reuses operator id: 20] -Output [6]: [i_product_name#44, i_brand#45, i_class#46, i_category#47, sum#48, count#49] +Output [6]: [i_product_name#40, i_brand#41, i_class#42, i_category#43, sum#44, count#45] (30) HashAggregate [codegen id : 16] -Input [6]: [i_product_name#44, i_brand#45, i_class#46, i_category#47, sum#48, count#49] -Keys [4]: [i_product_name#44, i_brand#45, i_class#46, i_category#47] -Functions [1]: [avg(inv_quantity_on_hand#50)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#50)#17] -Results [3]: [i_product_name#44, i_brand#45, avg(inv_quantity_on_hand#50)#17 AS qoh#51] +Input [6]: [i_product_name#40, i_brand#41, i_class#42, i_category#43, sum#44, count#45] +Keys [4]: [i_product_name#40, i_brand#41, i_class#42, i_category#43] +Functions [1]: [avg(inv_quantity_on_hand#46)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#46)#17] +Results [3]: [i_product_name#40, i_brand#41, avg(inv_quantity_on_hand#46)#17 AS qoh#47] (31) HashAggregate [codegen id : 16] -Input [3]: [i_product_name#44, i_brand#45, qoh#51] -Keys [2]: [i_product_name#44, i_brand#45] -Functions [1]: [partial_avg(qoh#51)] -Aggregate Attributes [2]: [sum#52, count#53] -Results [4]: [i_product_name#44, i_brand#45, sum#54, count#55] +Input [3]: [i_product_name#40, i_brand#41, qoh#47] +Keys [2]: [i_product_name#40, i_brand#41] +Functions [1]: [partial_avg(qoh#47)] +Aggregate Attributes [2]: [sum#48, count#49] +Results [4]: [i_product_name#40, i_brand#41, sum#50, count#51] (32) Exchange -Input [4]: [i_product_name#44, i_brand#45, sum#54, count#55] -Arguments: hashpartitioning(i_product_name#44, i_brand#45, 5), ENSURE_REQUIREMENTS, [plan_id=5] +Input [4]: [i_product_name#40, i_brand#41, sum#50, count#51] +Arguments: hashpartitioning(i_product_name#40, i_brand#41, 5), ENSURE_REQUIREMENTS, [plan_id=5] (33) HashAggregate [codegen id : 17] -Input [4]: [i_product_name#44, i_brand#45, sum#54, count#55] -Keys [2]: [i_product_name#44, i_brand#45] -Functions [1]: [avg(qoh#51)] -Aggregate Attributes [1]: [avg(qoh#51)#56] -Results [5]: [i_product_name#44, i_brand#45, null AS i_class#57, null 
AS i_category#58, avg(qoh#51)#56 AS qoh#59] +Input [4]: [i_product_name#40, i_brand#41, sum#50, count#51] +Keys [2]: [i_product_name#40, i_brand#41] +Functions [1]: [avg(qoh#47)] +Aggregate Attributes [1]: [avg(qoh#47)#52] +Results [5]: [i_product_name#40, i_brand#41, null AS i_class#53, null AS i_category#54, avg(qoh#47)#52 AS qoh#55] (34) ReusedExchange [Reuses operator id: 20] -Output [6]: [i_product_name#60, i_brand#61, i_class#62, i_category#63, sum#64, count#65] +Output [6]: [i_product_name#56, i_brand#57, i_class#58, i_category#59, sum#60, count#61] (35) HashAggregate [codegen id : 22] -Input [6]: [i_product_name#60, i_brand#61, i_class#62, i_category#63, sum#64, count#65] -Keys [4]: [i_product_name#60, i_brand#61, i_class#62, i_category#63] -Functions [1]: [avg(inv_quantity_on_hand#66)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#66)#17] -Results [2]: [i_product_name#60, avg(inv_quantity_on_hand#66)#17 AS qoh#67] +Input [6]: [i_product_name#56, i_brand#57, i_class#58, i_category#59, sum#60, count#61] +Keys [4]: [i_product_name#56, i_brand#57, i_class#58, i_category#59] +Functions [1]: [avg(inv_quantity_on_hand#62)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#62)#17] +Results [2]: [i_product_name#56, avg(inv_quantity_on_hand#62)#17 AS qoh#63] (36) HashAggregate [codegen id : 22] -Input [2]: [i_product_name#60, qoh#67] -Keys [1]: [i_product_name#60] -Functions [1]: [partial_avg(qoh#67)] -Aggregate Attributes [2]: [sum#68, count#69] -Results [3]: [i_product_name#60, sum#70, count#71] +Input [2]: [i_product_name#56, qoh#63] +Keys [1]: [i_product_name#56] +Functions [1]: [partial_avg(qoh#63)] +Aggregate Attributes [2]: [sum#64, count#65] +Results [3]: [i_product_name#56, sum#66, count#67] (37) Exchange -Input [3]: [i_product_name#60, sum#70, count#71] -Arguments: hashpartitioning(i_product_name#60, 5), ENSURE_REQUIREMENTS, [plan_id=6] +Input [3]: [i_product_name#56, sum#66, count#67] +Arguments: hashpartitioning(i_product_name#56, 5), ENSURE_REQUIREMENTS, [plan_id=6] (38) HashAggregate [codegen id : 23] -Input [3]: [i_product_name#60, sum#70, count#71] -Keys [1]: [i_product_name#60] -Functions [1]: [avg(qoh#67)] -Aggregate Attributes [1]: [avg(qoh#67)#72] -Results [5]: [i_product_name#60, null AS i_brand#73, null AS i_class#74, null AS i_category#75, avg(qoh#67)#72 AS qoh#76] +Input [3]: [i_product_name#56, sum#66, count#67] +Keys [1]: [i_product_name#56] +Functions [1]: [avg(qoh#63)] +Aggregate Attributes [1]: [avg(qoh#63)#68] +Results [5]: [i_product_name#56, null AS i_brand#69, null AS i_class#70, null AS i_category#71, avg(qoh#63)#68 AS qoh#72] (39) ReusedExchange [Reuses operator id: 20] -Output [6]: [i_product_name#77, i_brand#78, i_class#79, i_category#80, sum#81, count#82] +Output [6]: [i_product_name#73, i_brand#74, i_class#75, i_category#76, sum#77, count#78] (40) HashAggregate [codegen id : 28] -Input [6]: [i_product_name#77, i_brand#78, i_class#79, i_category#80, sum#81, count#82] -Keys [4]: [i_product_name#77, i_brand#78, i_class#79, i_category#80] -Functions [1]: [avg(inv_quantity_on_hand#83)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#83)#17] -Results [1]: [avg(inv_quantity_on_hand#83)#17 AS qoh#84] +Input [6]: [i_product_name#73, i_brand#74, i_class#75, i_category#76, sum#77, count#78] +Keys [4]: [i_product_name#73, i_brand#74, i_class#75, i_category#76] +Functions [1]: [avg(inv_quantity_on_hand#79)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#79)#17] +Results [1]: [avg(inv_quantity_on_hand#79)#17 AS qoh#80] (41) HashAggregate 
[codegen id : 28] -Input [1]: [qoh#84] +Input [1]: [qoh#80] Keys: [] -Functions [1]: [partial_avg(qoh#84)] -Aggregate Attributes [2]: [sum#85, count#86] -Results [2]: [sum#87, count#88] +Functions [1]: [partial_avg(qoh#80)] +Aggregate Attributes [2]: [sum#81, count#82] +Results [2]: [sum#83, count#84] (42) Exchange -Input [2]: [sum#87, count#88] +Input [2]: [sum#83, count#84] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=7] (43) HashAggregate [codegen id : 29] -Input [2]: [sum#87, count#88] +Input [2]: [sum#83, count#84] Keys: [] -Functions [1]: [avg(qoh#84)] -Aggregate Attributes [1]: [avg(qoh#84)#89] -Results [5]: [null AS i_product_name#90, null AS i_brand#91, null AS i_class#92, null AS i_category#93, avg(qoh#84)#89 AS qoh#94] +Functions [1]: [avg(qoh#80)] +Aggregate Attributes [1]: [avg(qoh#80)#85] +Results [5]: [null AS i_product_name#86, null AS i_brand#87, null AS i_class#88, null AS i_category#89, avg(qoh#80)#85 AS qoh#90] (44) Union (45) TakeOrderedAndProject -Input [5]: [i_product_name#24, i_brand#25, i_class#26, i_category#27, qoh#28] -Arguments: 100, [qoh#28 ASC NULLS FIRST, i_product_name#24 ASC NULLS FIRST, i_brand#25 ASC NULLS FIRST, i_class#26 ASC NULLS FIRST, i_category#27 ASC NULLS FIRST], [i_product_name#24, i_brand#25, i_class#26, i_category#27, qoh#28] +Input [5]: [i_product_name#11, i_brand#8, i_class#9, i_category#10, qoh#24] +Arguments: 100, [qoh#24 ASC NULLS FIRST, i_product_name#11 ASC NULLS FIRST, i_brand#8 ASC NULLS FIRST, i_class#9 ASC NULLS FIRST, i_category#10 ASC NULLS FIRST], [i_product_name#11, i_brand#8, i_class#9, i_category#10, qoh#24] ===== Subqueries ===== @@ -291,22 +291,22 @@ BroadcastExchange (50) (46) Scan parquet spark_catalog.default.date_dim -Output [2]: [d_date_sk#6, d_month_seq#95] +Output [2]: [d_date_sk#6, d_month_seq#91] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1212), LessThanOrEqual(d_month_seq,1223), IsNotNull(d_date_sk)] ReadSchema: struct (47) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#6, d_month_seq#95] +Input [2]: [d_date_sk#6, d_month_seq#91] (48) Filter [codegen id : 1] -Input [2]: [d_date_sk#6, d_month_seq#95] -Condition : (((isnotnull(d_month_seq#95) AND (d_month_seq#95 >= 1212)) AND (d_month_seq#95 <= 1223)) AND isnotnull(d_date_sk#6)) +Input [2]: [d_date_sk#6, d_month_seq#91] +Condition : (((isnotnull(d_month_seq#91) AND (d_month_seq#91 >= 1212)) AND (d_month_seq#91 <= 1223)) AND isnotnull(d_date_sk#6)) (49) Project [codegen id : 1] Output [1]: [d_date_sk#6] -Input [2]: [d_date_sk#6, d_month_seq#95] +Input [2]: [d_date_sk#6, d_month_seq#91] (50) BroadcastExchange Input [1]: [d_date_sk#6] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/simplified.txt index d747066f5945b..22f73cc9b9db5 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/simplified.txt @@ -1,7 +1,7 @@ TakeOrderedAndProject [qoh,i_product_name,i_brand,i_class,i_category] Union WholeStageCodegen (5) - HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(qoh),i_product_name,i_brand,i_class,i_category,qoh,sum,count] + HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(qoh),qoh,sum,count] HashAggregate 
[i_product_name,i_brand,i_class,i_category,qoh] [sum,count,sum,count] HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(inv_quantity_on_hand),qoh,sum,count] InputAdapter diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt index a4c009f8219b4..9c28ff9f351d8 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt @@ -186,265 +186,265 @@ Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, Keys [8]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] Functions [1]: [sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))] Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22] -Results [9]: [i_category#16 AS i_category#23, i_class#15 AS i_class#24, i_brand#14 AS i_brand#25, i_product_name#17 AS i_product_name#26, d_year#8 AS d_year#27, d_qoy#10 AS d_qoy#28, d_moy#9 AS d_moy#29, s_store_id#12 AS s_store_id#30, cast(sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22 as decimal(38,2)) AS sumsales#31] +Results [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, cast(sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22 as decimal(38,2)) AS sumsales#23] (25) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39, sum#40, isEmpty#41] +Output [10]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31, sum#32, isEmpty#33] (26) HashAggregate [codegen id : 16] -Input [10]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39, sum#40, isEmpty#41] -Keys [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39] -Functions [1]: [sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))#22] -Results [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))#22 AS sumsales#44] +Input [10]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31, sum#32, isEmpty#33] +Keys [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31] +Functions [1]: [sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))#22] +Results [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))#22 AS sumsales#36] (27) HashAggregate [codegen id : 16] -Input [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sumsales#44] -Keys [7]: [i_category#32, i_class#33, i_brand#34, 
i_product_name#35, d_year#36, d_qoy#37, d_moy#38] -Functions [1]: [partial_sum(sumsales#44)] -Aggregate Attributes [2]: [sum#45, isEmpty#46] -Results [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] +Input [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sumsales#36] +Keys [7]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30] +Functions [1]: [partial_sum(sumsales#36)] +Aggregate Attributes [2]: [sum#37, isEmpty#38] +Results [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] (28) Exchange -Input [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] -Arguments: hashpartitioning(i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, 5), ENSURE_REQUIREMENTS, [plan_id=5] +Input [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] +Arguments: hashpartitioning(i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, 5), ENSURE_REQUIREMENTS, [plan_id=5] (29) HashAggregate [codegen id : 17] -Input [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] -Keys [7]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38] -Functions [1]: [sum(sumsales#44)] -Aggregate Attributes [1]: [sum(sumsales#44)#49] -Results [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, null AS s_store_id#50, sum(sumsales#44)#49 AS sumsales#51] +Input [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] +Keys [7]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30] +Functions [1]: [sum(sumsales#36)] +Aggregate Attributes [1]: [sum(sumsales#36)#41] +Results [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, null AS s_store_id#42, sum(sumsales#36)#41 AS sumsales#43] (30) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59, sum#60, isEmpty#61] +Output [10]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51, sum#52, isEmpty#53] (31) HashAggregate [codegen id : 25] -Input [10]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59, sum#60, isEmpty#61] -Keys [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59] -Functions [1]: [sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))#22] -Results [7]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))#22 AS sumsales#64] +Input [10]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51, sum#52, isEmpty#53] +Keys [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51] +Functions [1]: 
[sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))#22] +Results [7]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))#22 AS sumsales#56] (32) HashAggregate [codegen id : 25] -Input [7]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sumsales#64] -Keys [6]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57] -Functions [1]: [partial_sum(sumsales#64)] -Aggregate Attributes [2]: [sum#65, isEmpty#66] -Results [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] +Input [7]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sumsales#56] +Keys [6]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49] +Functions [1]: [partial_sum(sumsales#56)] +Aggregate Attributes [2]: [sum#57, isEmpty#58] +Results [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] (33) Exchange -Input [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] -Arguments: hashpartitioning(i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, 5), ENSURE_REQUIREMENTS, [plan_id=6] +Input [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] +Arguments: hashpartitioning(i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, 5), ENSURE_REQUIREMENTS, [plan_id=6] (34) HashAggregate [codegen id : 26] -Input [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] -Keys [6]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57] -Functions [1]: [sum(sumsales#64)] -Aggregate Attributes [1]: [sum(sumsales#64)#69] -Results [9]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, null AS d_moy#70, null AS s_store_id#71, sum(sumsales#64)#69 AS sumsales#72] +Input [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] +Keys [6]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49] +Functions [1]: [sum(sumsales#56)] +Aggregate Attributes [1]: [sum(sumsales#56)#61] +Results [9]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, null AS d_moy#62, null AS s_store_id#63, sum(sumsales#56)#61 AS sumsales#64] (35) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80, sum#81, isEmpty#82] +Output [10]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72, sum#73, isEmpty#74] (36) HashAggregate [codegen id : 34] -Input [10]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80, sum#81, isEmpty#82] -Keys [8]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80] -Functions [1]: [sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))#22] -Results [6]: 
[i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))#22 AS sumsales#85] +Input [10]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72, sum#73, isEmpty#74] +Keys [8]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72] +Functions [1]: [sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))#22] +Results [6]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))#22 AS sumsales#77] (37) HashAggregate [codegen id : 34] -Input [6]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sumsales#85] -Keys [5]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77] -Functions [1]: [partial_sum(sumsales#85)] -Aggregate Attributes [2]: [sum#86, isEmpty#87] -Results [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] +Input [6]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sumsales#77] +Keys [5]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69] +Functions [1]: [partial_sum(sumsales#77)] +Aggregate Attributes [2]: [sum#78, isEmpty#79] +Results [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] (38) Exchange -Input [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] -Arguments: hashpartitioning(i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, 5), ENSURE_REQUIREMENTS, [plan_id=7] +Input [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] +Arguments: hashpartitioning(i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, 5), ENSURE_REQUIREMENTS, [plan_id=7] (39) HashAggregate [codegen id : 35] -Input [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] -Keys [5]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77] -Functions [1]: [sum(sumsales#85)] -Aggregate Attributes [1]: [sum(sumsales#85)#90] -Results [9]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, null AS d_qoy#91, null AS d_moy#92, null AS s_store_id#93, sum(sumsales#85)#90 AS sumsales#94] +Input [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] +Keys [5]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69] +Functions [1]: [sum(sumsales#77)] +Aggregate Attributes [1]: [sum(sumsales#77)#82] +Results [9]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, null AS d_qoy#83, null AS d_moy#84, null AS s_store_id#85, sum(sumsales#77)#82 AS sumsales#86] (40) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102, sum#103, isEmpty#104] +Output [10]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94, sum#95, isEmpty#96] (41) HashAggregate [codegen id : 43] -Input [10]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102, sum#103, isEmpty#104] -Keys [8]: [i_category#95, 
i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102] -Functions [1]: [sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))#22] -Results [5]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))#22 AS sumsales#107] +Input [10]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94, sum#95, isEmpty#96] +Keys [8]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94] +Functions [1]: [sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))#22] +Results [5]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))#22 AS sumsales#99] (42) HashAggregate [codegen id : 43] -Input [5]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sumsales#107] -Keys [4]: [i_category#95, i_class#96, i_brand#97, i_product_name#98] -Functions [1]: [partial_sum(sumsales#107)] -Aggregate Attributes [2]: [sum#108, isEmpty#109] -Results [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] +Input [5]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sumsales#99] +Keys [4]: [i_category#87, i_class#88, i_brand#89, i_product_name#90] +Functions [1]: [partial_sum(sumsales#99)] +Aggregate Attributes [2]: [sum#100, isEmpty#101] +Results [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] (43) Exchange -Input [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] -Arguments: hashpartitioning(i_category#95, i_class#96, i_brand#97, i_product_name#98, 5), ENSURE_REQUIREMENTS, [plan_id=8] +Input [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] +Arguments: hashpartitioning(i_category#87, i_class#88, i_brand#89, i_product_name#90, 5), ENSURE_REQUIREMENTS, [plan_id=8] (44) HashAggregate [codegen id : 44] -Input [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] -Keys [4]: [i_category#95, i_class#96, i_brand#97, i_product_name#98] -Functions [1]: [sum(sumsales#107)] -Aggregate Attributes [1]: [sum(sumsales#107)#112] -Results [9]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, null AS d_year#113, null AS d_qoy#114, null AS d_moy#115, null AS s_store_id#116, sum(sumsales#107)#112 AS sumsales#117] +Input [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] +Keys [4]: [i_category#87, i_class#88, i_brand#89, i_product_name#90] +Functions [1]: [sum(sumsales#99)] +Aggregate Attributes [1]: [sum(sumsales#99)#104] +Results [9]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, null AS d_year#105, null AS d_qoy#106, null AS d_moy#107, null AS s_store_id#108, sum(sumsales#99)#104 AS sumsales#109] (45) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125, sum#126, isEmpty#127] +Output [10]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117, sum#118, 
isEmpty#119] (46) HashAggregate [codegen id : 52] -Input [10]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125, sum#126, isEmpty#127] -Keys [8]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125] -Functions [1]: [sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))#22] -Results [4]: [i_category#118, i_class#119, i_brand#120, sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))#22 AS sumsales#130] +Input [10]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117, sum#118, isEmpty#119] +Keys [8]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117] +Functions [1]: [sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))#22] +Results [4]: [i_category#110, i_class#111, i_brand#112, sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))#22 AS sumsales#122] (47) HashAggregate [codegen id : 52] -Input [4]: [i_category#118, i_class#119, i_brand#120, sumsales#130] -Keys [3]: [i_category#118, i_class#119, i_brand#120] -Functions [1]: [partial_sum(sumsales#130)] -Aggregate Attributes [2]: [sum#131, isEmpty#132] -Results [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] +Input [4]: [i_category#110, i_class#111, i_brand#112, sumsales#122] +Keys [3]: [i_category#110, i_class#111, i_brand#112] +Functions [1]: [partial_sum(sumsales#122)] +Aggregate Attributes [2]: [sum#123, isEmpty#124] +Results [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] (48) Exchange -Input [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] -Arguments: hashpartitioning(i_category#118, i_class#119, i_brand#120, 5), ENSURE_REQUIREMENTS, [plan_id=9] +Input [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] +Arguments: hashpartitioning(i_category#110, i_class#111, i_brand#112, 5), ENSURE_REQUIREMENTS, [plan_id=9] (49) HashAggregate [codegen id : 53] -Input [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] -Keys [3]: [i_category#118, i_class#119, i_brand#120] -Functions [1]: [sum(sumsales#130)] -Aggregate Attributes [1]: [sum(sumsales#130)#135] -Results [9]: [i_category#118, i_class#119, i_brand#120, null AS i_product_name#136, null AS d_year#137, null AS d_qoy#138, null AS d_moy#139, null AS s_store_id#140, sum(sumsales#130)#135 AS sumsales#141] +Input [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] +Keys [3]: [i_category#110, i_class#111, i_brand#112] +Functions [1]: [sum(sumsales#122)] +Aggregate Attributes [1]: [sum(sumsales#122)#127] +Results [9]: [i_category#110, i_class#111, i_brand#112, null AS i_product_name#128, null AS d_year#129, null AS d_qoy#130, null AS d_moy#131, null AS s_store_id#132, sum(sumsales#122)#127 AS sumsales#133] (50) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149, sum#150, isEmpty#151] +Output [10]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, 
s_store_id#141, sum#142, isEmpty#143] (51) HashAggregate [codegen id : 61] -Input [10]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149, sum#150, isEmpty#151] -Keys [8]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149] -Functions [1]: [sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))#22] -Results [3]: [i_category#142, i_class#143, sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))#22 AS sumsales#154] +Input [10]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141, sum#142, isEmpty#143] +Keys [8]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141] +Functions [1]: [sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))#22] +Results [3]: [i_category#134, i_class#135, sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))#22 AS sumsales#146] (52) HashAggregate [codegen id : 61] -Input [3]: [i_category#142, i_class#143, sumsales#154] -Keys [2]: [i_category#142, i_class#143] -Functions [1]: [partial_sum(sumsales#154)] -Aggregate Attributes [2]: [sum#155, isEmpty#156] -Results [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] +Input [3]: [i_category#134, i_class#135, sumsales#146] +Keys [2]: [i_category#134, i_class#135] +Functions [1]: [partial_sum(sumsales#146)] +Aggregate Attributes [2]: [sum#147, isEmpty#148] +Results [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] (53) Exchange -Input [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] -Arguments: hashpartitioning(i_category#142, i_class#143, 5), ENSURE_REQUIREMENTS, [plan_id=10] +Input [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] +Arguments: hashpartitioning(i_category#134, i_class#135, 5), ENSURE_REQUIREMENTS, [plan_id=10] (54) HashAggregate [codegen id : 62] -Input [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] -Keys [2]: [i_category#142, i_class#143] -Functions [1]: [sum(sumsales#154)] -Aggregate Attributes [1]: [sum(sumsales#154)#159] -Results [9]: [i_category#142, i_class#143, null AS i_brand#160, null AS i_product_name#161, null AS d_year#162, null AS d_qoy#163, null AS d_moy#164, null AS s_store_id#165, sum(sumsales#154)#159 AS sumsales#166] +Input [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] +Keys [2]: [i_category#134, i_class#135] +Functions [1]: [sum(sumsales#146)] +Aggregate Attributes [1]: [sum(sumsales#146)#151] +Results [9]: [i_category#134, i_class#135, null AS i_brand#152, null AS i_product_name#153, null AS d_year#154, null AS d_qoy#155, null AS d_moy#156, null AS s_store_id#157, sum(sumsales#146)#151 AS sumsales#158] (55) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174, sum#175, isEmpty#176] +Output [10]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166, sum#167, isEmpty#168] (56) HashAggregate [codegen id : 70] -Input [10]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, 
d_qoy#172, d_moy#173, s_store_id#174, sum#175, isEmpty#176] -Keys [8]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174] -Functions [1]: [sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))#22] -Results [2]: [i_category#167, sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))#22 AS sumsales#179] +Input [10]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166, sum#167, isEmpty#168] +Keys [8]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166] +Functions [1]: [sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))#22] +Results [2]: [i_category#159, sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))#22 AS sumsales#171] (57) HashAggregate [codegen id : 70] -Input [2]: [i_category#167, sumsales#179] -Keys [1]: [i_category#167] -Functions [1]: [partial_sum(sumsales#179)] -Aggregate Attributes [2]: [sum#180, isEmpty#181] -Results [3]: [i_category#167, sum#182, isEmpty#183] +Input [2]: [i_category#159, sumsales#171] +Keys [1]: [i_category#159] +Functions [1]: [partial_sum(sumsales#171)] +Aggregate Attributes [2]: [sum#172, isEmpty#173] +Results [3]: [i_category#159, sum#174, isEmpty#175] (58) Exchange -Input [3]: [i_category#167, sum#182, isEmpty#183] -Arguments: hashpartitioning(i_category#167, 5), ENSURE_REQUIREMENTS, [plan_id=11] +Input [3]: [i_category#159, sum#174, isEmpty#175] +Arguments: hashpartitioning(i_category#159, 5), ENSURE_REQUIREMENTS, [plan_id=11] (59) HashAggregate [codegen id : 71] -Input [3]: [i_category#167, sum#182, isEmpty#183] -Keys [1]: [i_category#167] -Functions [1]: [sum(sumsales#179)] -Aggregate Attributes [1]: [sum(sumsales#179)#184] -Results [9]: [i_category#167, null AS i_class#185, null AS i_brand#186, null AS i_product_name#187, null AS d_year#188, null AS d_qoy#189, null AS d_moy#190, null AS s_store_id#191, sum(sumsales#179)#184 AS sumsales#192] +Input [3]: [i_category#159, sum#174, isEmpty#175] +Keys [1]: [i_category#159] +Functions [1]: [sum(sumsales#171)] +Aggregate Attributes [1]: [sum(sumsales#171)#176] +Results [9]: [i_category#159, null AS i_class#177, null AS i_brand#178, null AS i_product_name#179, null AS d_year#180, null AS d_qoy#181, null AS d_moy#182, null AS s_store_id#183, sum(sumsales#171)#176 AS sumsales#184] (60) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200, sum#201, isEmpty#202] +Output [10]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192, sum#193, isEmpty#194] (61) HashAggregate [codegen id : 79] -Input [10]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200, sum#201, isEmpty#202] -Keys [8]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200] -Functions [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#203 * 
cast(ss_quantity#204 as decimal(10,0))), 0.00))#22] -Results [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))#22 AS sumsales#205] +Input [10]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192, sum#193, isEmpty#194] +Keys [8]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192] +Functions [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))#22] +Results [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))#22 AS sumsales#197] (62) HashAggregate [codegen id : 79] -Input [1]: [sumsales#205] +Input [1]: [sumsales#197] Keys: [] -Functions [1]: [partial_sum(sumsales#205)] -Aggregate Attributes [2]: [sum#206, isEmpty#207] -Results [2]: [sum#208, isEmpty#209] +Functions [1]: [partial_sum(sumsales#197)] +Aggregate Attributes [2]: [sum#198, isEmpty#199] +Results [2]: [sum#200, isEmpty#201] (63) Exchange -Input [2]: [sum#208, isEmpty#209] +Input [2]: [sum#200, isEmpty#201] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=12] (64) HashAggregate [codegen id : 80] -Input [2]: [sum#208, isEmpty#209] +Input [2]: [sum#200, isEmpty#201] Keys: [] -Functions [1]: [sum(sumsales#205)] -Aggregate Attributes [1]: [sum(sumsales#205)#210] -Results [9]: [null AS i_category#211, null AS i_class#212, null AS i_brand#213, null AS i_product_name#214, null AS d_year#215, null AS d_qoy#216, null AS d_moy#217, null AS s_store_id#218, sum(sumsales#205)#210 AS sumsales#219] +Functions [1]: [sum(sumsales#197)] +Aggregate Attributes [1]: [sum(sumsales#197)#202] +Results [9]: [null AS i_category#203, null AS i_class#204, null AS i_brand#205, null AS i_product_name#206, null AS d_year#207, null AS d_qoy#208, null AS d_moy#209, null AS s_store_id#210, sum(sumsales#197)#202 AS sumsales#211] (65) Union (66) Sort [codegen id : 81] -Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] -Arguments: [i_category#23 ASC NULLS FIRST, sumsales#31 DESC NULLS LAST], false, 0 +Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] +Arguments: [i_category#16 ASC NULLS FIRST, sumsales#23 DESC NULLS LAST], false, 0 (67) WindowGroupLimit -Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] -Arguments: [i_category#23], [sumsales#31 DESC NULLS LAST], rank(sumsales#31), 100, Partial +Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] +Arguments: [i_category#16], [sumsales#23 DESC NULLS LAST], rank(sumsales#23), 100, Partial (68) Exchange -Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] -Arguments: hashpartitioning(i_category#23, 5), ENSURE_REQUIREMENTS, [plan_id=13] +Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] +Arguments: hashpartitioning(i_category#16, 5), ENSURE_REQUIREMENTS, [plan_id=13] (69) Sort [codegen id : 82] -Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] 
-Arguments: [i_category#23 ASC NULLS FIRST, sumsales#31 DESC NULLS LAST], false, 0 +Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] +Arguments: [i_category#16 ASC NULLS FIRST, sumsales#23 DESC NULLS LAST], false, 0 (70) WindowGroupLimit -Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] -Arguments: [i_category#23], [sumsales#31 DESC NULLS LAST], rank(sumsales#31), 100, Final +Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] +Arguments: [i_category#16], [sumsales#23 DESC NULLS LAST], rank(sumsales#23), 100, Final (71) Window -Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] -Arguments: [rank(sumsales#31) windowspecdefinition(i_category#23, sumsales#31 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#220], [i_category#23], [sumsales#31 DESC NULLS LAST] +Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] +Arguments: [rank(sumsales#23) windowspecdefinition(i_category#16, sumsales#23 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#212], [i_category#16], [sumsales#23 DESC NULLS LAST] (72) Filter [codegen id : 83] -Input [10]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] -Condition : (rk#220 <= 100) +Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] +Condition : (rk#212 <= 100) (73) TakeOrderedAndProject -Input [10]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] -Arguments: 100, [i_category#23 ASC NULLS FIRST, i_class#24 ASC NULLS FIRST, i_brand#25 ASC NULLS FIRST, i_product_name#26 ASC NULLS FIRST, d_year#27 ASC NULLS FIRST, d_qoy#28 ASC NULLS FIRST, d_moy#29 ASC NULLS FIRST, s_store_id#30 ASC NULLS FIRST, sumsales#31 ASC NULLS FIRST, rk#220 ASC NULLS FIRST], [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] +Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] +Arguments: 100, [i_category#16 ASC NULLS FIRST, i_class#15 ASC NULLS FIRST, i_brand#14 ASC NULLS FIRST, i_product_name#17 ASC NULLS FIRST, d_year#8 ASC NULLS FIRST, d_qoy#10 ASC NULLS FIRST, d_moy#9 ASC NULLS FIRST, s_store_id#12 ASC NULLS FIRST, sumsales#23 ASC NULLS FIRST, rk#212 ASC NULLS FIRST], [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] ===== Subqueries ===== @@ -457,22 +457,22 @@ BroadcastExchange (78) (74) Scan parquet spark_catalog.default.date_dim -Output [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] +Output [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1212), LessThanOrEqual(d_month_seq,1223), IsNotNull(d_date_sk)] ReadSchema: struct (75) ColumnarToRow [codegen id : 1] -Input [5]: [d_date_sk#7, 
d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] +Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] (76) Filter [codegen id : 1] -Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] -Condition : (((isnotnull(d_month_seq#221) AND (d_month_seq#221 >= 1212)) AND (d_month_seq#221 <= 1223)) AND isnotnull(d_date_sk#7)) +Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] +Condition : (((isnotnull(d_month_seq#213) AND (d_month_seq#213 >= 1212)) AND (d_month_seq#213 <= 1223)) AND isnotnull(d_date_sk#7)) (77) Project [codegen id : 1] Output [4]: [d_date_sk#7, d_year#8, d_moy#9, d_qoy#10] -Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] +Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] (78) BroadcastExchange Input [4]: [d_date_sk#7, d_year#8, d_moy#9, d_qoy#10] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt index b6a4358c4d43b..795fa297b9bad 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt @@ -14,7 +14,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ InputAdapter Union WholeStageCodegen (8) - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce((ss_sales_price * cast(ss_quantity as decimal(10,0))), 0.00)),i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce((ss_sales_price * cast(ss_quantity as decimal(10,0))), 0.00)),sumsales,sum,isEmpty] InputAdapter Exchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id] #2 WholeStageCodegen (7) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt index 417af4fe924ee..75d526da4ba71 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt @@ -171,265 +171,265 @@ Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, Keys [8]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] Functions [1]: [sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))] Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22] -Results [9]: [i_category#16 AS i_category#23, i_class#15 AS i_class#24, i_brand#14 AS i_brand#25, i_product_name#17 AS i_product_name#26, d_year#8 AS d_year#27, d_qoy#10 AS d_qoy#28, d_moy#9 AS d_moy#29, s_store_id#12 AS s_store_id#30, cast(sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22 as decimal(38,2)) AS sumsales#31] +Results [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, cast(sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22 as decimal(38,2)) AS sumsales#23] (22) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#32, 
i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39, sum#40, isEmpty#41] +Output [10]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31, sum#32, isEmpty#33] (23) HashAggregate [codegen id : 10] -Input [10]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39, sum#40, isEmpty#41] -Keys [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39] -Functions [1]: [sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))#22] -Results [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))#22 AS sumsales#44] +Input [10]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31, sum#32, isEmpty#33] +Keys [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31] +Functions [1]: [sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))#22] +Results [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))#22 AS sumsales#36] (24) HashAggregate [codegen id : 10] -Input [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sumsales#44] -Keys [7]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38] -Functions [1]: [partial_sum(sumsales#44)] -Aggregate Attributes [2]: [sum#45, isEmpty#46] -Results [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] +Input [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sumsales#36] +Keys [7]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30] +Functions [1]: [partial_sum(sumsales#36)] +Aggregate Attributes [2]: [sum#37, isEmpty#38] +Results [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] (25) Exchange -Input [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] -Arguments: hashpartitioning(i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, 5), ENSURE_REQUIREMENTS, [plan_id=4] +Input [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] +Arguments: hashpartitioning(i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, 5), ENSURE_REQUIREMENTS, [plan_id=4] (26) HashAggregate [codegen id : 11] -Input [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] -Keys [7]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38] -Functions [1]: [sum(sumsales#44)] -Aggregate Attributes [1]: [sum(sumsales#44)#49] -Results [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, null AS 
s_store_id#50, sum(sumsales#44)#49 AS sumsales#51] +Input [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] +Keys [7]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30] +Functions [1]: [sum(sumsales#36)] +Aggregate Attributes [1]: [sum(sumsales#36)#41] +Results [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, null AS s_store_id#42, sum(sumsales#36)#41 AS sumsales#43] (27) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59, sum#60, isEmpty#61] +Output [10]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51, sum#52, isEmpty#53] (28) HashAggregate [codegen id : 16] -Input [10]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59, sum#60, isEmpty#61] -Keys [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59] -Functions [1]: [sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))#22] -Results [7]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))#22 AS sumsales#64] +Input [10]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51, sum#52, isEmpty#53] +Keys [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51] +Functions [1]: [sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))#22] +Results [7]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))#22 AS sumsales#56] (29) HashAggregate [codegen id : 16] -Input [7]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sumsales#64] -Keys [6]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57] -Functions [1]: [partial_sum(sumsales#64)] -Aggregate Attributes [2]: [sum#65, isEmpty#66] -Results [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] +Input [7]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sumsales#56] +Keys [6]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49] +Functions [1]: [partial_sum(sumsales#56)] +Aggregate Attributes [2]: [sum#57, isEmpty#58] +Results [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] (30) Exchange -Input [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] -Arguments: hashpartitioning(i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, 5), ENSURE_REQUIREMENTS, [plan_id=5] +Input [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] +Arguments: hashpartitioning(i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, 5), 
ENSURE_REQUIREMENTS, [plan_id=5] (31) HashAggregate [codegen id : 17] -Input [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] -Keys [6]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57] -Functions [1]: [sum(sumsales#64)] -Aggregate Attributes [1]: [sum(sumsales#64)#69] -Results [9]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, null AS d_moy#70, null AS s_store_id#71, sum(sumsales#64)#69 AS sumsales#72] +Input [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] +Keys [6]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49] +Functions [1]: [sum(sumsales#56)] +Aggregate Attributes [1]: [sum(sumsales#56)#61] +Results [9]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, null AS d_moy#62, null AS s_store_id#63, sum(sumsales#56)#61 AS sumsales#64] (32) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80, sum#81, isEmpty#82] +Output [10]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72, sum#73, isEmpty#74] (33) HashAggregate [codegen id : 22] -Input [10]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80, sum#81, isEmpty#82] -Keys [8]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80] -Functions [1]: [sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))#22] -Results [6]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))#22 AS sumsales#85] +Input [10]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72, sum#73, isEmpty#74] +Keys [8]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72] +Functions [1]: [sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))#22] +Results [6]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))#22 AS sumsales#77] (34) HashAggregate [codegen id : 22] -Input [6]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sumsales#85] -Keys [5]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77] -Functions [1]: [partial_sum(sumsales#85)] -Aggregate Attributes [2]: [sum#86, isEmpty#87] -Results [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] +Input [6]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sumsales#77] +Keys [5]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69] +Functions [1]: [partial_sum(sumsales#77)] +Aggregate Attributes [2]: [sum#78, isEmpty#79] +Results [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] (35) Exchange -Input [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] -Arguments: 
hashpartitioning(i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, 5), ENSURE_REQUIREMENTS, [plan_id=6] +Input [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] +Arguments: hashpartitioning(i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, 5), ENSURE_REQUIREMENTS, [plan_id=6] (36) HashAggregate [codegen id : 23] -Input [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] -Keys [5]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77] -Functions [1]: [sum(sumsales#85)] -Aggregate Attributes [1]: [sum(sumsales#85)#90] -Results [9]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, null AS d_qoy#91, null AS d_moy#92, null AS s_store_id#93, sum(sumsales#85)#90 AS sumsales#94] +Input [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] +Keys [5]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69] +Functions [1]: [sum(sumsales#77)] +Aggregate Attributes [1]: [sum(sumsales#77)#82] +Results [9]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, null AS d_qoy#83, null AS d_moy#84, null AS s_store_id#85, sum(sumsales#77)#82 AS sumsales#86] (37) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102, sum#103, isEmpty#104] +Output [10]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94, sum#95, isEmpty#96] (38) HashAggregate [codegen id : 28] -Input [10]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102, sum#103, isEmpty#104] -Keys [8]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102] -Functions [1]: [sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))#22] -Results [5]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))#22 AS sumsales#107] +Input [10]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94, sum#95, isEmpty#96] +Keys [8]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94] +Functions [1]: [sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))#22] +Results [5]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))#22 AS sumsales#99] (39) HashAggregate [codegen id : 28] -Input [5]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sumsales#107] -Keys [4]: [i_category#95, i_class#96, i_brand#97, i_product_name#98] -Functions [1]: [partial_sum(sumsales#107)] -Aggregate Attributes [2]: [sum#108, isEmpty#109] -Results [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] +Input [5]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sumsales#99] +Keys [4]: [i_category#87, i_class#88, i_brand#89, i_product_name#90] +Functions [1]: [partial_sum(sumsales#99)] +Aggregate Attributes [2]: 
[sum#100, isEmpty#101] +Results [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] (40) Exchange -Input [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] -Arguments: hashpartitioning(i_category#95, i_class#96, i_brand#97, i_product_name#98, 5), ENSURE_REQUIREMENTS, [plan_id=7] +Input [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] +Arguments: hashpartitioning(i_category#87, i_class#88, i_brand#89, i_product_name#90, 5), ENSURE_REQUIREMENTS, [plan_id=7] (41) HashAggregate [codegen id : 29] -Input [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] -Keys [4]: [i_category#95, i_class#96, i_brand#97, i_product_name#98] -Functions [1]: [sum(sumsales#107)] -Aggregate Attributes [1]: [sum(sumsales#107)#112] -Results [9]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, null AS d_year#113, null AS d_qoy#114, null AS d_moy#115, null AS s_store_id#116, sum(sumsales#107)#112 AS sumsales#117] +Input [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] +Keys [4]: [i_category#87, i_class#88, i_brand#89, i_product_name#90] +Functions [1]: [sum(sumsales#99)] +Aggregate Attributes [1]: [sum(sumsales#99)#104] +Results [9]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, null AS d_year#105, null AS d_qoy#106, null AS d_moy#107, null AS s_store_id#108, sum(sumsales#99)#104 AS sumsales#109] (42) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125, sum#126, isEmpty#127] +Output [10]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117, sum#118, isEmpty#119] (43) HashAggregate [codegen id : 34] -Input [10]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125, sum#126, isEmpty#127] -Keys [8]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125] -Functions [1]: [sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))#22] -Results [4]: [i_category#118, i_class#119, i_brand#120, sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))#22 AS sumsales#130] +Input [10]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117, sum#118, isEmpty#119] +Keys [8]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117] +Functions [1]: [sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))#22] +Results [4]: [i_category#110, i_class#111, i_brand#112, sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))#22 AS sumsales#122] (44) HashAggregate [codegen id : 34] -Input [4]: [i_category#118, i_class#119, i_brand#120, sumsales#130] -Keys [3]: [i_category#118, i_class#119, i_brand#120] -Functions [1]: [partial_sum(sumsales#130)] -Aggregate Attributes [2]: [sum#131, isEmpty#132] -Results [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] +Input [4]: [i_category#110, i_class#111, i_brand#112, 
sumsales#122] +Keys [3]: [i_category#110, i_class#111, i_brand#112] +Functions [1]: [partial_sum(sumsales#122)] +Aggregate Attributes [2]: [sum#123, isEmpty#124] +Results [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] (45) Exchange -Input [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] -Arguments: hashpartitioning(i_category#118, i_class#119, i_brand#120, 5), ENSURE_REQUIREMENTS, [plan_id=8] +Input [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] +Arguments: hashpartitioning(i_category#110, i_class#111, i_brand#112, 5), ENSURE_REQUIREMENTS, [plan_id=8] (46) HashAggregate [codegen id : 35] -Input [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] -Keys [3]: [i_category#118, i_class#119, i_brand#120] -Functions [1]: [sum(sumsales#130)] -Aggregate Attributes [1]: [sum(sumsales#130)#135] -Results [9]: [i_category#118, i_class#119, i_brand#120, null AS i_product_name#136, null AS d_year#137, null AS d_qoy#138, null AS d_moy#139, null AS s_store_id#140, sum(sumsales#130)#135 AS sumsales#141] +Input [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] +Keys [3]: [i_category#110, i_class#111, i_brand#112] +Functions [1]: [sum(sumsales#122)] +Aggregate Attributes [1]: [sum(sumsales#122)#127] +Results [9]: [i_category#110, i_class#111, i_brand#112, null AS i_product_name#128, null AS d_year#129, null AS d_qoy#130, null AS d_moy#131, null AS s_store_id#132, sum(sumsales#122)#127 AS sumsales#133] (47) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149, sum#150, isEmpty#151] +Output [10]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141, sum#142, isEmpty#143] (48) HashAggregate [codegen id : 40] -Input [10]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149, sum#150, isEmpty#151] -Keys [8]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149] -Functions [1]: [sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))#22] -Results [3]: [i_category#142, i_class#143, sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))#22 AS sumsales#154] +Input [10]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141, sum#142, isEmpty#143] +Keys [8]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141] +Functions [1]: [sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))#22] +Results [3]: [i_category#134, i_class#135, sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))#22 AS sumsales#146] (49) HashAggregate [codegen id : 40] -Input [3]: [i_category#142, i_class#143, sumsales#154] -Keys [2]: [i_category#142, i_class#143] -Functions [1]: [partial_sum(sumsales#154)] -Aggregate Attributes [2]: [sum#155, isEmpty#156] -Results [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] +Input [3]: [i_category#134, i_class#135, sumsales#146] +Keys [2]: [i_category#134, i_class#135] 
+Functions [1]: [partial_sum(sumsales#146)] +Aggregate Attributes [2]: [sum#147, isEmpty#148] +Results [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] (50) Exchange -Input [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] -Arguments: hashpartitioning(i_category#142, i_class#143, 5), ENSURE_REQUIREMENTS, [plan_id=9] +Input [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] +Arguments: hashpartitioning(i_category#134, i_class#135, 5), ENSURE_REQUIREMENTS, [plan_id=9] (51) HashAggregate [codegen id : 41] -Input [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] -Keys [2]: [i_category#142, i_class#143] -Functions [1]: [sum(sumsales#154)] -Aggregate Attributes [1]: [sum(sumsales#154)#159] -Results [9]: [i_category#142, i_class#143, null AS i_brand#160, null AS i_product_name#161, null AS d_year#162, null AS d_qoy#163, null AS d_moy#164, null AS s_store_id#165, sum(sumsales#154)#159 AS sumsales#166] +Input [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] +Keys [2]: [i_category#134, i_class#135] +Functions [1]: [sum(sumsales#146)] +Aggregate Attributes [1]: [sum(sumsales#146)#151] +Results [9]: [i_category#134, i_class#135, null AS i_brand#152, null AS i_product_name#153, null AS d_year#154, null AS d_qoy#155, null AS d_moy#156, null AS s_store_id#157, sum(sumsales#146)#151 AS sumsales#158] (52) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174, sum#175, isEmpty#176] +Output [10]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166, sum#167, isEmpty#168] (53) HashAggregate [codegen id : 46] -Input [10]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174, sum#175, isEmpty#176] -Keys [8]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174] -Functions [1]: [sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))#22] -Results [2]: [i_category#167, sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))#22 AS sumsales#179] +Input [10]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166, sum#167, isEmpty#168] +Keys [8]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166] +Functions [1]: [sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))#22] +Results [2]: [i_category#159, sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))#22 AS sumsales#171] (54) HashAggregate [codegen id : 46] -Input [2]: [i_category#167, sumsales#179] -Keys [1]: [i_category#167] -Functions [1]: [partial_sum(sumsales#179)] -Aggregate Attributes [2]: [sum#180, isEmpty#181] -Results [3]: [i_category#167, sum#182, isEmpty#183] +Input [2]: [i_category#159, sumsales#171] +Keys [1]: [i_category#159] +Functions [1]: [partial_sum(sumsales#171)] +Aggregate Attributes [2]: [sum#172, isEmpty#173] +Results [3]: [i_category#159, sum#174, isEmpty#175] (55) Exchange -Input [3]: [i_category#167, sum#182, isEmpty#183] -Arguments: hashpartitioning(i_category#167, 5), 
ENSURE_REQUIREMENTS, [plan_id=10] +Input [3]: [i_category#159, sum#174, isEmpty#175] +Arguments: hashpartitioning(i_category#159, 5), ENSURE_REQUIREMENTS, [plan_id=10] (56) HashAggregate [codegen id : 47] -Input [3]: [i_category#167, sum#182, isEmpty#183] -Keys [1]: [i_category#167] -Functions [1]: [sum(sumsales#179)] -Aggregate Attributes [1]: [sum(sumsales#179)#184] -Results [9]: [i_category#167, null AS i_class#185, null AS i_brand#186, null AS i_product_name#187, null AS d_year#188, null AS d_qoy#189, null AS d_moy#190, null AS s_store_id#191, sum(sumsales#179)#184 AS sumsales#192] +Input [3]: [i_category#159, sum#174, isEmpty#175] +Keys [1]: [i_category#159] +Functions [1]: [sum(sumsales#171)] +Aggregate Attributes [1]: [sum(sumsales#171)#176] +Results [9]: [i_category#159, null AS i_class#177, null AS i_brand#178, null AS i_product_name#179, null AS d_year#180, null AS d_qoy#181, null AS d_moy#182, null AS s_store_id#183, sum(sumsales#171)#176 AS sumsales#184] (57) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200, sum#201, isEmpty#202] +Output [10]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192, sum#193, isEmpty#194] (58) HashAggregate [codegen id : 52] -Input [10]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200, sum#201, isEmpty#202] -Keys [8]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200] -Functions [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))#22] -Results [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))#22 AS sumsales#205] +Input [10]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192, sum#193, isEmpty#194] +Keys [8]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192] +Functions [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))#22] +Results [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))#22 AS sumsales#197] (59) HashAggregate [codegen id : 52] -Input [1]: [sumsales#205] +Input [1]: [sumsales#197] Keys: [] -Functions [1]: [partial_sum(sumsales#205)] -Aggregate Attributes [2]: [sum#206, isEmpty#207] -Results [2]: [sum#208, isEmpty#209] +Functions [1]: [partial_sum(sumsales#197)] +Aggregate Attributes [2]: [sum#198, isEmpty#199] +Results [2]: [sum#200, isEmpty#201] (60) Exchange -Input [2]: [sum#208, isEmpty#209] +Input [2]: [sum#200, isEmpty#201] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=11] (61) HashAggregate [codegen id : 53] -Input [2]: [sum#208, isEmpty#209] +Input [2]: [sum#200, isEmpty#201] Keys: [] -Functions [1]: [sum(sumsales#205)] -Aggregate Attributes [1]: [sum(sumsales#205)#210] -Results [9]: [null AS i_category#211, null AS i_class#212, null AS i_brand#213, null AS i_product_name#214, null AS d_year#215, null AS d_qoy#216, null AS d_moy#217, null AS s_store_id#218, sum(sumsales#205)#210 AS sumsales#219] +Functions [1]: [sum(sumsales#197)] 
+Aggregate Attributes [1]: [sum(sumsales#197)#202] +Results [9]: [null AS i_category#203, null AS i_class#204, null AS i_brand#205, null AS i_product_name#206, null AS d_year#207, null AS d_qoy#208, null AS d_moy#209, null AS s_store_id#210, sum(sumsales#197)#202 AS sumsales#211] (62) Union (63) Sort [codegen id : 54] -Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] -Arguments: [i_category#23 ASC NULLS FIRST, sumsales#31 DESC NULLS LAST], false, 0 +Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] +Arguments: [i_category#16 ASC NULLS FIRST, sumsales#23 DESC NULLS LAST], false, 0 (64) WindowGroupLimit -Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] -Arguments: [i_category#23], [sumsales#31 DESC NULLS LAST], rank(sumsales#31), 100, Partial +Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] +Arguments: [i_category#16], [sumsales#23 DESC NULLS LAST], rank(sumsales#23), 100, Partial (65) Exchange -Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] -Arguments: hashpartitioning(i_category#23, 5), ENSURE_REQUIREMENTS, [plan_id=12] +Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] +Arguments: hashpartitioning(i_category#16, 5), ENSURE_REQUIREMENTS, [plan_id=12] (66) Sort [codegen id : 55] -Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] -Arguments: [i_category#23 ASC NULLS FIRST, sumsales#31 DESC NULLS LAST], false, 0 +Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] +Arguments: [i_category#16 ASC NULLS FIRST, sumsales#23 DESC NULLS LAST], false, 0 (67) WindowGroupLimit -Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] -Arguments: [i_category#23], [sumsales#31 DESC NULLS LAST], rank(sumsales#31), 100, Final +Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] +Arguments: [i_category#16], [sumsales#23 DESC NULLS LAST], rank(sumsales#23), 100, Final (68) Window -Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] -Arguments: [rank(sumsales#31) windowspecdefinition(i_category#23, sumsales#31 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#220], [i_category#23], [sumsales#31 DESC NULLS LAST] +Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] +Arguments: [rank(sumsales#23) windowspecdefinition(i_category#16, sumsales#23 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#212], [i_category#16], [sumsales#23 DESC NULLS LAST] (69) Filter [codegen id : 56] -Input [10]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] -Condition : (rk#220 <= 100) +Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, 
d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] +Condition : (rk#212 <= 100) (70) TakeOrderedAndProject -Input [10]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] -Arguments: 100, [i_category#23 ASC NULLS FIRST, i_class#24 ASC NULLS FIRST, i_brand#25 ASC NULLS FIRST, i_product_name#26 ASC NULLS FIRST, d_year#27 ASC NULLS FIRST, d_qoy#28 ASC NULLS FIRST, d_moy#29 ASC NULLS FIRST, s_store_id#30 ASC NULLS FIRST, sumsales#31 ASC NULLS FIRST, rk#220 ASC NULLS FIRST], [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] +Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] +Arguments: 100, [i_category#16 ASC NULLS FIRST, i_class#15 ASC NULLS FIRST, i_brand#14 ASC NULLS FIRST, i_product_name#17 ASC NULLS FIRST, d_year#8 ASC NULLS FIRST, d_qoy#10 ASC NULLS FIRST, d_moy#9 ASC NULLS FIRST, s_store_id#12 ASC NULLS FIRST, sumsales#23 ASC NULLS FIRST, rk#212 ASC NULLS FIRST], [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] ===== Subqueries ===== @@ -442,22 +442,22 @@ BroadcastExchange (75) (71) Scan parquet spark_catalog.default.date_dim -Output [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] +Output [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1212), LessThanOrEqual(d_month_seq,1223), IsNotNull(d_date_sk)] ReadSchema: struct (72) ColumnarToRow [codegen id : 1] -Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] +Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] (73) Filter [codegen id : 1] -Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] -Condition : (((isnotnull(d_month_seq#221) AND (d_month_seq#221 >= 1212)) AND (d_month_seq#221 <= 1223)) AND isnotnull(d_date_sk#7)) +Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] +Condition : (((isnotnull(d_month_seq#213) AND (d_month_seq#213 >= 1212)) AND (d_month_seq#213 <= 1223)) AND isnotnull(d_date_sk#7)) (74) Project [codegen id : 1] Output [4]: [d_date_sk#7, d_year#8, d_moy#9, d_qoy#10] -Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] +Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] (75) BroadcastExchange Input [4]: [d_date_sk#7, d_year#8, d_moy#9, d_qoy#10] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt index 5a43dced056bd..89393f265a49f 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt @@ -14,7 +14,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ InputAdapter Union WholeStageCodegen (5) - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce((ss_sales_price * cast(ss_quantity as decimal(10,0))), 0.00)),i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sumsales,sum,isEmpty] + HashAggregate 
[i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce((ss_sales_price * cast(ss_quantity as decimal(10,0))), 0.00)),sumsales,sum,isEmpty] InputAdapter Exchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id] #2 WholeStageCodegen (4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index 273e8e08fd7a5..3b987529afcb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -349,7 +349,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession |FROM VALUES (0), (1), (2), (10) AS tab(col); |""".stripMargin).collect() }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL", + condition = "DATATYPE_MISMATCH.UNEXPECTED_NULL", parameters = Map( "exprName" -> "accuracy", "sqlExpr" -> "\"percentile_approx(col, array(0.5, 0.4, 0.1), NULL)\""), @@ -363,7 +363,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession |FROM VALUES (0), (1), (2), (10) AS tab(col); |""".stripMargin).collect() }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL", + condition = "DATATYPE_MISMATCH.UNEXPECTED_NULL", parameters = Map( "exprName" -> "percentage", "sqlExpr" -> "\"percentile_approx(col, NULL, 100)\""), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BitmapExpressionsQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BitmapExpressionsQuerySuite.scala index 0778599d54f49..97814e3bac44b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BitmapExpressionsQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BitmapExpressionsQuerySuite.scala @@ -214,7 +214,7 @@ class BitmapExpressionsQuerySuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("bitmap_count(a)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"bitmap_count(a)\"", "paramIndex" -> "first", @@ -236,7 +236,7 @@ class BitmapExpressionsQuerySuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("bitmap_or_agg(a)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"bitmap_or_agg(a)\"", "paramIndex" -> "first", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala index 9b39a2295e7d6..af97856fd222e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala @@ -98,7 +98,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { } checkError( exception = exception, - errorClass = "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE", + condition = "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE", parameters = Map( "exprName" -> "estimatedNumItems", "valueRange" -> "[0, positive]", @@ -126,7 +126,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { } checkError( exception = exception, - errorClass = "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE", + condition = 
"DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE", parameters = Map( "exprName" -> "numBits", "valueRange" -> "[0, positive]", @@ -159,7 +159,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { } checkError( exception = exception1, - errorClass = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE", + condition = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE", parameters = Map( "functionName" -> "`bloom_filter_agg`", "sqlExpr" -> "\"bloom_filter_agg(a, 1000000, 8388608)\"", @@ -182,7 +182,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { } checkError( exception = exception2, - errorClass = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE", + condition = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE", parameters = Map( "functionName" -> "`bloom_filter_agg`", "sqlExpr" -> "\"bloom_filter_agg(a, 2, (2 * 8))\"", @@ -205,7 +205,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { } checkError( exception = exception3, - errorClass = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE", + condition = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE", parameters = Map( "functionName" -> "`bloom_filter_agg`", "sqlExpr" -> "\"bloom_filter_agg(a, CAST(2 AS BIGINT), 5)\"", @@ -228,7 +228,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { } checkError( exception = exception4, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL", + condition = "DATATYPE_MISMATCH.UNEXPECTED_NULL", parameters = Map( "exprName" -> "estimatedNumItems or numBits", "sqlExpr" -> "\"bloom_filter_agg(a, NULL, 5)\"" @@ -248,7 +248,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { } checkError( exception = exception5, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL", + condition = "DATATYPE_MISMATCH.UNEXPECTED_NULL", parameters = Map( "exprName" -> "estimatedNumItems or numBits", "sqlExpr" -> "\"bloom_filter_agg(a, 5, NULL)\"" @@ -268,7 +268,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { } checkError( exception = exception1, - errorClass = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE", + condition = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE", parameters = Map( "sqlExpr" -> "\"might_contain(1.0, 1)\"", "functionName" -> "`might_contain`", @@ -289,7 +289,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { } checkError( exception = exception2, - errorClass = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE", + condition = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE", parameters = Map( "sqlExpr" -> "\"might_contain(NULL, 0.1)\"", "functionName" -> "`might_contain`", @@ -314,7 +314,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { } checkError( exception = exception1, - errorClass = "DATATYPE_MISMATCH.BLOOM_FILTER_BINARY_OP_WRONG_TYPE", + condition = "DATATYPE_MISMATCH.BLOOM_FILTER_BINARY_OP_WRONG_TYPE", parameters = Map( "sqlExpr" -> "\"might_contain(CAST(a AS BINARY), CAST(5 AS BIGINT))\"", "functionName" -> "`might_contain`", @@ -335,7 +335,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { } checkError( exception = exception2, - errorClass = "DATATYPE_MISMATCH.BLOOM_FILTER_BINARY_OP_WRONG_TYPE", + condition = "DATATYPE_MISMATCH.BLOOM_FILTER_BINARY_OP_WRONG_TYPE", parameters = Map( "sqlExpr" -> "\"might_contain(scalarsubquery(a), CAST(5 AS BIGINT))\"", "functionName" -> "`might_contain`", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 3ac433f31288c..b1e53aec81637 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -157,7 +157,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils sql("CACHE TABLE tempView AS SELECT 1") } checkError(e, - errorClass = "TEMP_TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TEMP_TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> "`tempView`")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 59a566a3f2967..d3b11274fe1c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -67,7 +67,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { def assertLengthCheckFailure(func: () => Unit): Unit = { checkError( exception = intercept[SparkRuntimeException](func()), - errorClass = "EXCEED_LIMIT_LENGTH", + condition = "EXCEED_LIMIT_LENGTH", parameters = Map("limit" -> "5") ) } @@ -702,7 +702,7 @@ class BasicCharVarcharTestSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("""SELECT from_json('{"a": "str"}', 'a CHAR(5)')""") }, - errorClass = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING", + condition = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING", parameters = Map.empty, context = ExpectedContext( fragment = "from_json('{\"a\": \"str\"}', 'a CHAR(5)')", @@ -724,19 +724,19 @@ class BasicCharVarcharTestSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { spark.createDataFrame(df.collectAsList(), schema) }, - errorClass = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" + condition = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" ) checkError( exception = intercept[AnalysisException] { spark.createDataFrame(df.rdd, schema) }, - errorClass = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" + condition = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" ) checkError( exception = intercept[AnalysisException] { spark.createDataFrame(df.toJavaRDD, schema) }, - errorClass = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" + condition = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" ) withSQLConf((SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key, "true")) { val df1 = spark.createDataFrame(df.collectAsList(), schema) @@ -750,12 +750,12 @@ class BasicCharVarcharTestSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { spark.read.schema(new StructType().add("id", CharType(5))) }, - errorClass = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING") + condition = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING") checkError( exception = intercept[AnalysisException] { spark.read.schema("id char(5)") }, - errorClass = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" + condition = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" ) withSQLConf((SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key, "true")) { val ds = spark.range(10).map(_.toString) @@ -792,13 +792,13 @@ class BasicCharVarcharTestSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { spark.udf.register("testchar", () => "B", VarcharType(1)) }, - errorClass = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" + condition = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" ) checkError( exception = intercept[AnalysisException] { spark.udf.register("testchar2", (x: String) => x, VarcharType(1)) 
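
The only change in these suites (ApproximatePercentileQuerySuite, BitmapExpressionsQuerySuite, BloomFilterAggregateQuerySuite, CachedTableSuite, CharVarcharTestSuite, and the collation suites that follow) is renaming the `checkError` named argument from `errorClass` to `condition`; the intercepted exceptions and expected parameters are untouched. The sketch below uses a hypothetical stand-in (`TestError`, `CheckErrorSketch`) rather than the real test-kit helper; only the parameter rename and the call shape are taken from the patch.

```scala
// Hypothetical stand-in for the suites' checkError helper; only the parameter
// rename (errorClass -> condition) and the call shape come from the patch.
final case class TestError(condition: String, parameters: Map[String, String])
  extends RuntimeException(condition)

object CheckErrorSketch {
  def checkError(
      exception: TestError,
      condition: String, // previously named errorClass
      parameters: Map[String, String] = Map.empty): Unit = {
    assert(exception.condition == condition,
      s"expected error condition $condition, got ${exception.condition}")
    assert(parameters.forall { case (k, v) => exception.parameters.get(k).contains(v) })
  }

  def main(args: Array[String]): Unit = {
    val e = TestError("TEMP_TABLE_OR_VIEW_ALREADY_EXISTS", Map("relationName" -> "`tempView`"))
    // Call sites now pass condition = ... where they previously passed errorClass = ...
    checkError(e, condition = "TEMP_TABLE_OR_VIEW_ALREADY_EXISTS",
      parameters = Map("relationName" -> "`tempView`"))
    println("error condition matched")
  }
}
```
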
}, - errorClass = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" + condition = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" ) withSQLConf((SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key, "true")) { spark.udf.register("testchar", () => "B", VarcharType(1)) @@ -817,13 +817,13 @@ class BasicCharVarcharTestSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { spark.readStream.schema(new StructType().add("id", CharType(5))) }, - errorClass = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" + condition = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" ) checkError( exception = intercept[AnalysisException] { spark.readStream.schema("id char(5)") }, - errorClass = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" + condition = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" ) withSQLConf((SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key, "true")) { withTempPath { dir => @@ -845,7 +845,7 @@ class BasicCharVarcharTestSuite extends QueryTest with SharedSparkSession { val df = sql("SELECT * FROM t") checkError(exception = intercept[AnalysisException] { df.to(newSchema) - }, errorClass = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING", parameters = Map.empty) + }, condition = "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING", parameters = Map.empty) withSQLConf((SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key, "true")) { val df1 = df.to(newSchema) checkAnswer(df1, df.select("v", "c")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 2342722c0bb14..1d23774a51692 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.Timestamp import org.apache.spark.{SparkFunSuite, SparkRuntimeException} +import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.variant.ParseJson import org.apache.spark.sql.internal.SqlApiConf @@ -46,7 +47,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputEntry - List of all input entries that need to be generated * @param collationType - Flag defining collation type to use - * @return + * @return - List of data generated for expression instance creation */ def generateData( inputEntry: Seq[Any], @@ -54,23 +55,11 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi inputEntry.map(generateSingleEntry(_, collationType)) } - /** - * Helper function to generate single entry of data as a string. - * @param inputEntry - Single input entry that requires generation - * @param collationType - Flag defining collation type to use - * @return - */ - def generateDataAsStrings( - inputEntry: Seq[AbstractDataType], - collationType: CollationType): Seq[Any] = { - inputEntry.map(generateInputAsString(_, collationType)) - } - /** * Helper function to generate single entry of data. 
* @param inputEntry - Single input entry that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - Single input entry data */ def generateSingleEntry( inputEntry: Any, @@ -100,7 +89,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputType - Single input literal type that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - Literal/Expression containing expression ready for evaluation */ def generateLiterals( inputType: AbstractDataType, @@ -116,6 +105,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } case BooleanType => Literal(true) case _: DatetimeType => Literal(Timestamp.valueOf("2009-07-30 12:58:59")) + case DecimalType => Literal((new Decimal).set(5)) case _: DecimalType => Literal((new Decimal).set(5)) case _: DoubleType => Literal(5.0) case IntegerType | NumericType | IntegralType => Literal(5) @@ -158,11 +148,15 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case MapType => val key = generateLiterals(StringTypeAnyCollation, collationType) val value = generateLiterals(StringTypeAnyCollation, collationType) - Literal.create(Map(key -> value)) + CreateMap(Seq(key, value)) case MapType(keyType, valueType, _) => val key = generateLiterals(keyType, collationType) val value = generateLiterals(valueType, collationType) - Literal.create(Map(key -> value)) + CreateMap(Seq(key, value)) + case AbstractMapType(keyType, valueType) => + val key = generateLiterals(keyType, collationType) + val value = generateLiterals(valueType, collationType) + CreateMap(Seq(key, value)) case StructType => CreateNamedStruct( Seq(Literal("start"), generateLiterals(StringTypeAnyCollation, collationType), @@ -174,7 +168,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputType - Single input type that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - String representation of a input ready for SQL query */ def generateInputAsString( inputType: AbstractDataType, @@ -189,6 +183,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } case BooleanType => "True" case _: DatetimeType => "date'2016-04-08'" + case DecimalType => "5.0" case _: DecimalType => "5.0" case _: DoubleType => "5.0" case IntegerType | NumericType | IntegralType => "5" @@ -221,6 +216,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case MapType(keyType, valueType, _) => "map(" + generateInputAsString(keyType, collationType) + ", " + generateInputAsString(valueType, collationType) + ")" + case AbstractMapType(keyType, valueType) => + "map(" + generateInputAsString(keyType, collationType) + ", " + + generateInputAsString(valueType, collationType) + ")" case StructType => "named_struct( 'start', " + generateInputAsString(StringTypeAnyCollation, collationType) + ", 'end', " + generateInputAsString(StringTypeAnyCollation, collationType) + ")" @@ -234,7 +232,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputType - Single input type that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - String representation for SQL query of a inputType */ def generateInputTypeAsStrings( inputType: AbstractDataType, @@ -244,6 +242,7 @@ class 
CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case BinaryType => "BINARY" case BooleanType => "BOOLEAN" case _: DatetimeType => "DATE" + case DecimalType => "DECIMAL(2, 1)" case _: DecimalType => "DECIMAL(2, 1)" case _: DoubleType => "DOUBLE" case IntegerType | NumericType | IntegralType => "INT" @@ -275,6 +274,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case MapType(keyType, valueType, _) => "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " + generateInputTypeAsStrings(valueType, collationType) + ">" + case AbstractMapType(keyType, valueType) => + "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " + + generateInputTypeAsStrings(valueType, collationType) + ">" case StructType => "struct hasStringType(elementType) case TypeCollection(typeCollection) => typeCollection.exists(hasStringType) - case StructType => true case StructType(fields) => fields.exists(sf => hasStringType(sf.dataType)) case _ => false } @@ -310,7 +311,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * Helper function to replace expected parameters with expected input types. * @param inputTypes - Input types generated by ExpectsInputType.inputTypes * @param params - Parameters that are read from expression info - * @return + * @return - List of parameters where Expressions are replaced with input types */ def replaceExpressions(inputTypes: Seq[AbstractDataType], params: Seq[Class[_]]): Seq[Any] = { (inputTypes, params) match { @@ -325,7 +326,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi /** * Helper method to extract relevant expressions that can be walked over. - * @return + * @return - (List of relevant expressions that expect input, List of expressions to skip) */ def extractRelevantExpressions(): (Array[ExpressionInfo], List[String]) = { var expressionCounter = 0 @@ -384,6 +385,47 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi (funInfos, toSkip) } + /** + * Helper method to extract relevant expressions that can be walked over but are built with + * expression builder. + * + * @return - (List of expressions that are relevant builders, List of expressions to skip) + */ + def extractRelevantBuilders(): (Array[ExpressionInfo], List[String]) = { + var builderExpressionCounter = 0 + val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => + spark.sessionState.catalog.lookupFunctionInfo(funcId) + }.filter(funInfo => { + // make sure that there is a constructor. 
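
The extractRelevantBuilders helper added in this hunk (its probing loop continues just below) cannot know a builder's arity up front, so it invokes `build` with a progressively longer argument list and keeps the builder only if some arity is accepted. Here is a minimal sketch of that retry idea, with a plain `buildFn` standing in for the reflective `method.invoke` call; all names (`BuilderProbeSketch`, `probe`, `buildFn`) are hypothetical.

```scala
// Sketch of the arity-probing loop used when walking expression builders:
// keep appending one more argument and retry until some arity is accepted.
// Everything here is illustrative, not Spark's actual reflection code.
object BuilderProbeSketch {
  def probe[A](buildFn: Seq[A] => Any, sample: A, maxArity: Int = 10): Boolean = {
    var input: Seq[A] = Seq.empty
    var failures = 0
    for (_ <- 1 to maxArity) {
      input = input :+ sample
      try buildFn(input)
      catch { case _: Exception => failures += 1 }
    }
    failures < maxArity // at least one arity was accepted
  }

  def main(args: Array[String]): Unit = {
    // A toy "builder" that only accepts exactly two arguments.
    val twoArg: Seq[String] => Any =
      xs => if (xs.length == 2) xs.mkString("+") else sys.error("wrong arity")
    println(probe(twoArg, sample = "lit"))                   // true: arity 2 succeeded
    println(probe(_ => sys.error("never"), sample = "lit"))  // false: no arity worked
  }
}
```
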
+ val cl = Utils.classForName(funInfo.getClassName) + cl.isAssignableFrom(classOf[ExpressionBuilder]) + }).filter(funInfo => { + builderExpressionCounter = builderExpressionCounter + 1 + val cl = Utils.classForName(funInfo.getClassName) + val method = cl.getMethod("build", + Utils.classForName("java.lang.String"), + Utils.classForName("scala.collection.Seq")) + var input: Seq[Expression] = Seq.empty + var i = 0 + for (_ <- 1 to 10) { + input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary) + try { + method.invoke(null, funInfo.getClassName, input).asInstanceOf[ExpectsInputTypes] + } + catch { + case _: Exception => i = i + 1 + } + } + if (i == 10) false + else true + }).toArray + + logInfo("Total number of expression that are built: " + builderExpressionCounter) + logInfo("Number of extracted expressions of relevance: " + funInfos.length) + + (funInfos, List()) + } + /** * Helper function to generate string of an expression suitable for execution. * @param expr - Expression that needs to be converted @@ -441,10 +483,36 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * 5) Otherwise, check if exceptions are the same */ test("SPARK-48280: Expression Walker for expression evaluation") { - val (funInfos, toSkip) = extractRelevantExpressions() + val (funInfosExpr, toSkip) = extractRelevantExpressions() + val (funInfosBuild, _) = extractRelevantBuilders() + val funInfos = funInfosExpr ++ funInfosBuild for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - val cl = Utils.classForName(f.getClassName) + val TempCl = Utils.classForName(f.getClassName) + val cl = if (TempCl.isAssignableFrom(classOf[ExpressionBuilder])) { + val clTemp = Utils.classForName(f.getClassName) + val method = clTemp.getMethod("build", + Utils.classForName("java.lang.String"), + Utils.classForName("scala.collection.Seq")) + val instance = { + var input: Seq[Expression] = Seq.empty + var result: Expression = null + for (_ <- 1 to 10) { + input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary) + try { + val tempResult = method.invoke(null, f.getClassName, input) + if (result == null) result = tempResult.asInstanceOf[Expression] + } + catch { + case _: Exception => + } + } + result + } + instance.getClass + } + else Utils.classForName(f.getClassName) + val headConstructor = cl.getConstructors .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1 val params = headConstructor.getParameters.map(p => p.getType) @@ -526,10 +594,36 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * 5) Otherwise, check if exceptions are the same */ test("SPARK-48280: Expression Walker for codeGen generation") { - val (funInfos, toSkip) = extractRelevantExpressions() + val (funInfosExpr, toSkip) = extractRelevantExpressions() + val (funInfosBuild, _) = extractRelevantBuilders() + val funInfos = funInfosExpr ++ funInfosBuild for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - val cl = Utils.classForName(f.getClassName) + val TempCl = Utils.classForName(f.getClassName) + val cl = if (TempCl.isAssignableFrom(classOf[ExpressionBuilder])) { + val clTemp = Utils.classForName(f.getClassName) + val method = clTemp.getMethod("build", + Utils.classForName("java.lang.String"), + Utils.classForName("scala.collection.Seq")) + val instance = { + var input: Seq[Expression] = Seq.empty + var result: Expression = null + for (_ <- 1 to 10) { + input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary) + try { + val 
tempResult = method.invoke(null, f.getClassName, input) + if (result == null) result = tempResult.asInstanceOf[Expression] + } + catch { + case _: Exception => + } + } + result + } + instance.getClass + } + else Utils.classForName(f.getClassName) + val headConstructor = cl.getConstructors .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1 val params = headConstructor.getParameters.map(p => p.getType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 7d0f6c401c0d6..941d5cd31db40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -686,7 +686,7 @@ class CollationSQLExpressionsSuite val testQuery = sql(query) testQuery.collect() }, - errorClass = "INVALID_FORMAT.MISMATCH_INPUT", + condition = "INVALID_FORMAT.MISMATCH_INPUT", parameters = Map("inputType" -> "\"STRING\"", "input" -> "xx", "format" -> "999") ) } @@ -982,6 +982,7 @@ class CollationSQLExpressionsSuite StringToMapTestCase("1/AX2/BX3/C", "x", "/", "UNICODE_CI", Map("1" -> "A", "2" -> "B", "3" -> "C")) ) + val unsupportedTestCase = StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null) testCases.foreach(t => { // Unit test. val text = Literal.create(t.text, StringType(t.collation)) @@ -996,6 +997,29 @@ class CollationSQLExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(dataType)) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select str_to_map('${unsupportedTestCase.text}', '${unsupportedTestCase.pairDelim}', " + + s"'${unsupportedTestCase.keyValueDelim}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate UNICODE_AI, " + + "'?' collate UNICODE_AI, '?' 
collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'a:1,b:2,c:3' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext( + fragment = "str_to_map('a:1,b:2,c:3', '?', '?')", + start = 7, + stop = 41)) + } } test("Support RaiseError misc expression with collation") { @@ -1015,7 +1039,7 @@ class CollationSQLExpressionsSuite exception = intercept[SparkRuntimeException] { sql(query).collect() }, - errorClass = "USER_RAISED_EXCEPTION", + condition = "USER_RAISED_EXCEPTION", parameters = Map("errorMessage" -> t.errorMessage) ) } @@ -1193,7 +1217,7 @@ class CollationSQLExpressionsSuite exception = intercept[AnalysisException] { sql("SELECT mask(collate('ab-CD-12-@$','UNICODE'),collate('X','UNICODE_CI'),'x','0','#')") }, - errorClass = "COLLATION_MISMATCH.EXPLICIT", + condition = "COLLATION_MISMATCH.EXPLICIT", parameters = Map("explicitTypes" -> "`string collate UNICODE`, `string collate UNICODE_CI`") ) } @@ -1385,7 +1409,7 @@ class CollationSQLExpressionsSuite val testQuery = sql(query) testQuery.collect() }, - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map("badRecord" -> "{\"a\":1,", "failFastMode" -> "FAILFAST") ) } @@ -1489,7 +1513,7 @@ class CollationSQLExpressionsSuite val testQuery = sql(query) testQuery.collect() }, - errorClass = "INVALID_VARIANT_CAST", + condition = "INVALID_VARIANT_CAST", parameters = Map("value" -> "\"Spark\"", "dataType" -> "\"INT\"") ) } @@ -1576,9 +1600,9 @@ class CollationSQLExpressionsSuite SchemaOfVariantTestCase("null", "UTF8_BINARY", "VOID"), SchemaOfVariantTestCase("[]", "UTF8_LCASE", "ARRAY"), SchemaOfVariantTestCase("[{\"a\":true,\"b\":0}]", "UNICODE", - "ARRAY>"), + "ARRAY>"), SchemaOfVariantTestCase("[{\"A\":\"x\",\"B\":-1.00}]", "UNICODE_CI", - "ARRAY>") + "ARRAY>") ) // Supported collations @@ -1607,9 +1631,9 @@ class CollationSQLExpressionsSuite SchemaOfVariantAggTestCase("('1'), ('2'), ('3')", "UTF8_BINARY", "BIGINT"), SchemaOfVariantAggTestCase("('true'), ('false'), ('true')", "UTF8_LCASE", "BOOLEAN"), SchemaOfVariantAggTestCase("('{\"a\": 1}'), ('{\"b\": true}'), ('{\"c\": 1.23}')", - "UNICODE", "STRUCT"), + "UNICODE", "OBJECT"), SchemaOfVariantAggTestCase("('{\"A\": \"x\"}'), ('{\"B\": 9.99}'), ('{\"C\": 0}')", - "UNICODE_CI", "STRUCT") + "UNICODE_CI", "OBJECT") ) // Supported collations @@ -1770,7 +1794,7 @@ class CollationSQLExpressionsSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = params, queryContext = Array( ExpectedContext(objectType = "", @@ -1821,7 +1845,7 @@ class CollationSQLExpressionsSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = params, queryContext = Array( ExpectedContext(objectType = "", @@ -1869,7 +1893,7 @@ class CollationSQLExpressionsSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = params, queryContext = Array( ExpectedContext(objectType = "", @@ -2319,7 +2343,7 @@ class CollationSQLExpressionsSuite exception = intercept[ExtendedAnalysisException] { sql(queryFail).collect() 
}, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_STATIC_METHOD", + condition = "DATATYPE_MISMATCH.UNEXPECTED_STATIC_METHOD", parameters = Map( "methodName" -> "toHexString", "className" -> "java.lang.Integer", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala index 40cc6f19550d8..87dbbc65a3936 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala @@ -54,7 +54,7 @@ class CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"collate(ABC, UNICODE_CI) LIKE %b%\"", "paramIndex" -> "first", @@ -148,7 +148,7 @@ class CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"ilike(collate(ABC, UNICODE_CI), %b%)\"", "paramIndex" -> "first", @@ -188,7 +188,7 @@ class CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"likeall(collate(Foo, UNICODE_CI))\"", "paramIndex" -> "first", @@ -228,7 +228,7 @@ class CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"notlikeall(collate(Foo, UNICODE_CI))\"", "paramIndex" -> "first", @@ -268,7 +268,7 @@ class CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"likeany(collate(Foo, UNICODE_CI))\"", "paramIndex" -> "first", @@ -308,7 +308,7 @@ class CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"notlikeany(collate(Foo, UNICODE_CI))\"", "paramIndex" -> "first", @@ -348,7 +348,7 @@ class CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"RLIKE(collate(ABC, UNICODE_CI), .b.)\"", "paramIndex" -> "first", @@ -388,7 +388,7 @@ class CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"split(collate(ABC, UNICODE_CI), [b], -1)\"", "paramIndex" -> "first", @@ -429,7 +429,7 @@ class CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(s"SELECT regexp_replace(collate('ABCDE','$c1'), '.c.', collate('FFF','$c2'))") }, - errorClass = "COLLATION_MISMATCH.EXPLICIT", + condition = "COLLATION_MISMATCH.EXPLICIT", parameters = Map("explicitTypes" -> "`string`, `string collate UTF8_LCASE`") ) // Unsupported collations @@ -444,7 +444,7 @@ class 
CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"regexp_replace(collate(ABCDE, UNICODE_CI), .c., FFF, 1)\"", "paramIndex" -> "first", @@ -486,7 +486,7 @@ class CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"regexp_extract(collate(ABCDE, UNICODE_CI), .c., 0)\"", "paramIndex" -> "first", @@ -528,7 +528,7 @@ class CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"regexp_extract_all(collate(ABCDE, UNICODE_CI), .c., 0)\"", "paramIndex" -> "first", @@ -568,7 +568,7 @@ class CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"regexp_count(collate(ABCDE, UNICODE_CI), .c.)\"", "paramIndex" -> "first", @@ -608,7 +608,7 @@ class CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"regexp_substr(collate(ABCDE, UNICODE_CI), .c.)\"", "paramIndex" -> "first", @@ -648,7 +648,7 @@ class CollationSQLRegexpSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"regexp_instr(collate(ABCDE, UNICODE_CI), .c., 0)\"", "paramIndex" -> "first", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 412c003a0dbaa..fe9872ddaf575 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -98,6 +98,7 @@ class CollationStringExpressionsSuite SplitPartTestCase("1a2", "A", 2, "UTF8_LCASE", "2"), SplitPartTestCase("1a2", "A", 2, "UNICODE_CI", "2") ) + val unsupportedTestCase = SplitPartTestCase("1a2", "a", 2, "UNICODE_AI", "2") testCases.foreach(t => { // Unit test. val str = Literal.create(t.str, StringType(t.collation)) @@ -111,6 +112,26 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. 
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select split_part('${unsupportedTestCase.str}', '${unsupportedTestCase.delimiter}', " + + s"${unsupportedTestCase.partNum})" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"split_part('1a2' collate UNICODE_AI, 'a' collate UNICODE_AI, 2)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'1a2' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "split_part('1a2', 'a', 2)", start = 7, stop = 31) + ) + } } test("Support `StringSplitSQL` string expression with collation") { @@ -141,7 +162,7 @@ class CollationStringExpressionsSuite Cast(Literal.create("a"), StringType("UTF8_LCASE"))) CollationTypeCasts.transform(expr) }, - errorClass = "COLLATION_MISMATCH.IMPLICIT", + condition = "COLLATION_MISMATCH.IMPLICIT", sqlState = "42P21", parameters = Map.empty ) @@ -152,7 +173,7 @@ class CollationStringExpressionsSuite Collate(Literal.create("a"), "UTF8_LCASE")) CollationTypeCasts.transform(expr) }, - errorClass = "COLLATION_MISMATCH.EXPLICIT", + condition = "COLLATION_MISMATCH.EXPLICIT", sqlState = "42P21", parameters = Map("explicitTypes" -> "`string`, `string collate UTF8_LCASE`") ) @@ -166,6 +187,7 @@ class CollationStringExpressionsSuite ContainsTestCase("abcde", "FGH", "UTF8_LCASE", false), ContainsTestCase("abcde", "BCD", "UNICODE_CI", true) ) + val unsupportedTestCase = ContainsTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { // Unit test. val left = Literal.create(t.left, StringType(t.collation)) @@ -178,6 +200,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select contains('${unsupportedTestCase.left}', '${unsupportedTestCase.right}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"contains('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "contains('abcde', 'A')", start = 7, stop = 28) + ) + } } test("Support `SubstringIndex` expression with collation") { @@ -194,6 +235,7 @@ class CollationStringExpressionsSuite SubstringIndexTestCase("aaaaaaaaaa", "aa", 2, "UNICODE", "a"), SubstringIndexTestCase("wwwmapacheMorg", "M", -2, "UNICODE_CI", "apacheMorg") ) + val unsupportedTestCase = SubstringIndexTestCase("abacde", "a", 2, "UNICODE_AI", "cde") testCases.foreach(t => { // Unit test. val strExpr = Literal.create(t.strExpr, StringType(t.collation)) @@ -207,6 +249,29 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. 
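
Each of the new "Test unsupported collation" blocks in these string-expression suites follows the same pattern: evaluate the expression with UNICODE_AI-collated inputs and expect analysis to fail with DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE (SQLSTATE 42K09). A standalone sketch of that expectation follows, assuming a local Spark build with collation support; `UnsupportedCollationDemo` is a hypothetical name, and the explicit `collate` clause mirrors the sqlExpr strings shown in the expected errors rather than the suites' default-collation configuration.

```scala
// Illustrative only: reproduce the expectation behind the new negative tests.
// With UNICODE_AI collation, binary-only string expressions such as contains()
// are expected to fail analysis with DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE.
import org.apache.spark.sql.{AnalysisException, SparkSession}

object UnsupportedCollationDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("unsupported-collation-sketch")
      .getOrCreate()
    try {
      spark.sql("select contains('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI)").collect()
      println("unexpectedly succeeded")
    } catch {
      case e: AnalysisException =>
        // Expected per the suite: DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE (42K09)
        println(s"failed analysis as expected: ${e.getMessage}")
    } finally {
      spark.stop()
    }
  }
}
```
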
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select substring_index('${unsupportedTestCase.strExpr}', " + + s"'${unsupportedTestCase.delimExpr}', ${unsupportedTestCase.countExpr})" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"substring_index('abacde' collate UNICODE_AI, " + + "'a' collate UNICODE_AI, 2)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'abacde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext( + fragment = "substring_index('abacde', 'a', 2)", + start = 7, + stop = 39)) + } } test("Support `StringInStr` string expression with collation") { @@ -219,6 +284,7 @@ class CollationStringExpressionsSuite StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE_CI", 8), StringInStrTestCase("abİo12", "i̇o", "UNICODE_CI", 3) ) + val unsupportedTestCase = StringInStrTestCase("a", "abcde", "UNICODE_AI", 0) testCases.foreach(t => { // Unit test. val str = Literal.create(t.str, StringType(t.collation)) @@ -231,6 +297,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select instr('${unsupportedTestCase.str}', '${unsupportedTestCase.substr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"instr('a' collate UNICODE_AI, 'abcde' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'a' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "instr('a', 'abcde')", start = 7, stop = 25) + ) + } } test("Support `FindInSet` string expression with collation") { @@ -264,6 +349,7 @@ class CollationStringExpressionsSuite StartsWithTestCase("abcde", "FGH", "UTF8_LCASE", false), StartsWithTestCase("abcde", "ABC", "UNICODE_CI", true) ) + val unsupportedTestCase = StartsWithTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { // Unit test. val left = Literal.create(t.left, StringType(t.collation)) @@ -276,6 +362,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) } }) + // Test unsupported collation. 
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select startswith('${unsupportedTestCase.left}', '${unsupportedTestCase.right}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"startswith('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "startswith('abcde', 'A')", start = 7, stop = 30) + ) + } } test("Support `StringTranslate` string expression with collation") { @@ -291,6 +396,7 @@ class CollationStringExpressionsSuite StringTranslateTestCase("Translate", "Rn", "\u0000\u0000", "UNICODE", "Traslate"), StringTranslateTestCase("Translate", "Rn", "1234", "UNICODE_CI", "T1a2slate") ) + val unsupportedTestCase = StringTranslateTestCase("ABC", "AB", "12", "UNICODE_AI", "12C") testCases.foreach(t => { // Unit test. val srcExpr = Literal.create(t.srcExpr, StringType(t.collation)) @@ -304,6 +410,27 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select translate('${unsupportedTestCase.srcExpr}', " + + s"'${unsupportedTestCase.matchingExpr}', '${unsupportedTestCase.replaceExpr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"translate('ABC' collate UNICODE_AI, 'AB' collate UNICODE_AI, " + + "'12' collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'ABC' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "translate('ABC', 'AB', '12')", start = 7, stop = 34) + ) + } } test("Support `StringReplace` string expression with collation") { @@ -321,6 +448,7 @@ class CollationStringExpressionsSuite StringReplaceTestCase("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"), StringReplaceTestCase("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx") ) + val unsupportedTestCase = StringReplaceTestCase("abcde", "A", "B", "UNICODE_AI", "abcde") testCases.foreach(t => { // Unit test. val srcExpr = Literal.create(t.srcExpr, StringType(t.collation)) @@ -334,6 +462,27 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. 
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select replace('${unsupportedTestCase.srcExpr}', '${unsupportedTestCase.searchExpr}', " + + s"'${unsupportedTestCase.replaceExpr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"replace('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI, " + + "'B' collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "replace('abcde', 'A', 'B')", start = 7, stop = 32) + ) + } } test("Support `EndsWith` string expression with collation") { @@ -344,6 +493,7 @@ class CollationStringExpressionsSuite EndsWithTestCase("abcde", "FGH", "UTF8_LCASE", false), EndsWithTestCase("abcde", "CDE", "UNICODE_CI", true) ) + val unsupportedTestCase = EndsWithTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { // Unit test. val left = Literal.create(t.left, StringType(t.collation)) @@ -355,6 +505,25 @@ class CollationStringExpressionsSuite checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) } + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select endswith('${unsupportedTestCase.left}', '${unsupportedTestCase.right}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"endswith('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "endswith('abcde', 'A')", start = 7, stop = 28) + ) + } }) } @@ -1097,6 +1266,7 @@ class CollationStringExpressionsSuite StringLocateTestCase("aa", "Aaads", 0, "UNICODE_CI", 0), StringLocateTestCase("界x", "test大千世界X大千世界", 1, "UNICODE_CI", 8) ) + val unsupportedTestCase = StringLocateTestCase("aa", "Aaads", 0, "UNICODE_AI", 1) testCases.foreach(t => { // Unit test. val substr = Literal.create(t.substr, StringType(t.collation)) @@ -1110,6 +1280,26 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) } }) + // Test unsupported collation. 
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select locate('${unsupportedTestCase.substr}', '${unsupportedTestCase.str}', " + + s"${unsupportedTestCase.start})" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"locate('aa' collate UNICODE_AI, 'Aaads' collate UNICODE_AI, 0)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'aa' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "locate('aa', 'Aaads', 0)", start = 7, stop = 30) + ) + } } test("Support `StringTrimLeft` string expression with collation") { @@ -1124,6 +1314,7 @@ class CollationStringExpressionsSuite StringTrimLeftTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), StringTrimLeftTestCase(" asd ", None, "UNICODE_CI", "asd ") ) + val unsupportedTestCase = StringTrimLeftTestCase("xxasdxx", Some("x"), "UNICODE_AI", null) testCases.foreach(t => { // Unit test. val srcStr = Literal.create(t.srcStr, StringType(t.collation)) @@ -1137,6 +1328,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val trimString = s"'${unsupportedTestCase.trimStr.get}', " + val query = s"select ltrim($trimString'${unsupportedTestCase.srcStr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"TRIM(LEADING 'x' collate UNICODE_AI FROM 'xxasdxx' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'xxasdxx' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "ltrim('x', 'xxasdxx')", start = 7, stop = 27) + ) + } } test("Support `StringTrimRight` string expression with collation") { @@ -1151,6 +1361,7 @@ class CollationStringExpressionsSuite StringTrimRightTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), StringTrimRightTestCase(" asd ", None, "UNICODE_CI", " asd") ) + val unsupportedTestCase = StringTrimRightTestCase("xxasdxx", Some("x"), "UNICODE_AI", "xxasd") testCases.foreach(t => { // Unit test. val srcStr = Literal.create(t.srcStr, StringType(t.collation)) @@ -1164,6 +1375,26 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. 
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val trimString = s"'${unsupportedTestCase.trimStr.get}', " + val query = s"select rtrim($trimString'${unsupportedTestCase.srcStr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"TRIM(TRAILING 'x' collate UNICODE_AI FROM 'xxasdxx'" + + " collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'xxasdxx' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "rtrim('x', 'xxasdxx')", start = 7, stop = 27) + ) + } } test("Support `StringTrim` string expression with collation") { @@ -1178,6 +1409,7 @@ class CollationStringExpressionsSuite StringTrimTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), StringTrimTestCase(" asd ", None, "UNICODE_CI", "asd") ) + val unsupportedTestCase = StringTrimTestCase("xxasdxx", Some("x"), "UNICODE_AI", "asd") testCases.foreach(t => { // Unit test. val srcStr = Literal.create(t.srcStr, StringType(t.collation)) @@ -1191,6 +1423,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val trimString = s"'${unsupportedTestCase.trimStr.get}', " + val query = s"select trim($trimString'${unsupportedTestCase.srcStr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"TRIM(BOTH 'x' collate UNICODE_AI FROM 'xxasdxx' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'xxasdxx' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "trim('x', 'xxasdxx')", start = 7, stop = 26) + ) + } } test("Support `StringTrimBoth` string expression with collation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 5e7feec149c97..632b9305feb57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAg import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.{SqlApiConf, SQLConf} -import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION import org.apache.spark.sql.types.{ArrayType, MapType, StringType, StructField, StructType} class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { @@ -91,7 +90,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"select collate($args)") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", sqlState = "42605", parameters = Map( "functionName" -> "`collate`", @@ -106,7 +105,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("collate function invalid collation 
data type") { checkError( exception = intercept[AnalysisException](sql("select collate('abc', 123)")), - errorClass = "UNEXPECTED_INPUT_TYPE", + condition = "UNEXPECTED_INPUT_TYPE", sqlState = "42K09", Map( "functionName" -> "`collate`", @@ -122,7 +121,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkError( exception = intercept[AnalysisException] { sql("select collate('abc', cast(null as string))") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL", + condition = "DATATYPE_MISMATCH.UNEXPECTED_NULL", sqlState = "42K09", Map("exprName" -> "`collation`", "sqlExpr" -> "\"CAST(NULL AS STRING)\""), context = ExpectedContext( @@ -133,7 +132,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("collate function invalid input data type") { checkError( exception = intercept[ExtendedAnalysisException] { sql(s"select collate(1, 'UTF8_BINARY')") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = "42K09", parameters = Map( "sqlExpr" -> "\"collate(1, UTF8_BINARY)\"", @@ -152,21 +151,25 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("invalid collation name throws exception") { checkError( exception = intercept[SparkException] { sql("select 'aaa' collate UTF8_BS") }, - errorClass = "COLLATION_INVALID_NAME", + condition = "COLLATION_INVALID_NAME", sqlState = "42704", parameters = Map("collationName" -> "UTF8_BS", "proposals" -> "UTF8_LCASE")) } test("disable bucketing on collated string column") { - spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) def createTable(bucketColumns: String*): Unit = { val tableName = "test_partition_tbl" withTable(tableName) { sql( s""" - |CREATE TABLE $tableName - |(id INT, c1 STRING COLLATE UNICODE, c2 string) - |USING parquet + |CREATE TABLE $tableName ( + | id INT, + | c1 STRING COLLATE UNICODE, + | c2 STRING, + | struct_col STRUCT, + | array_col ARRAY, + | map_col MAP + |) USING parquet |CLUSTERED BY (${bucketColumns.mkString(",")}) |INTO 4 BUCKETS""".stripMargin ) @@ -177,15 +180,44 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { createTable("c2") createTable("id", "c2") - Seq(Seq("c1"), Seq("c1", "id"), Seq("c1", "c2")).foreach { bucketColumns => + val failBucketingColumns = Seq( + Seq("c1"), Seq("c1", "id"), Seq("c1", "c2"), + Seq("struct_col"), Seq("array_col"), Seq("map_col") + ) + + failBucketingColumns.foreach { bucketColumns => checkError( exception = intercept[AnalysisException] { createTable(bucketColumns: _*) }, - errorClass = "INVALID_BUCKET_COLUMN_DATA_TYPE", - parameters = Map("type" -> "\"STRING COLLATE UNICODE\"") - ); + condition = "INVALID_BUCKET_COLUMN_DATA_TYPE", + parameters = Map("type" -> ".*STRING COLLATE UNICODE.*"), + matchPVals = true + ) + } + } + + test("check difference betweeen SR_AI and SR_Latn_AI collations") { + // scalastyle:off nonascii + Seq( + ("c", "ć"), + ("c", "č"), + ("ć", "č"), + ("C", "Ć"), + ("C", "Č"), + ("Ć", "Č"), + ("s", "š"), + ("S", "Š"), + ("z", "ž"), + ("Z", "Ž") + ).foreach { + case (c1, c2) => + // SR_Latn_AI + checkAnswer(sql(s"SELECT '$c1' = '$c2' COLLATE SR_Latn_AI"), Row(false)) + // SR_AI + checkAnswer(sql(s"SELECT '$c1' = '$c2' COLLATE SR_AI"), Row(true)) } + // scalastyle:on nonascii } test("equality check respects collation") { @@ -248,7 +280,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { spark.sql(s"SELECT contains(collate('$left', 
'$leftCollationName')," + s"collate('$right', '$rightCollationName'))") }, - errorClass = "COLLATION_MISMATCH.EXPLICIT", + condition = "COLLATION_MISMATCH.EXPLICIT", sqlState = "42P21", parameters = Map( "explicitTypes" -> @@ -262,7 +294,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { spark.sql(s"SELECT startsWith(collate('$left', '$leftCollationName')," + s"collate('$right', '$rightCollationName'))") }, - errorClass = "COLLATION_MISMATCH.EXPLICIT", + condition = "COLLATION_MISMATCH.EXPLICIT", sqlState = "42P21", parameters = Map( "explicitTypes" -> @@ -276,7 +308,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { spark.sql(s"SELECT endsWith(collate('$left', '$leftCollationName')," + s"collate('$right', '$rightCollationName'))") }, - errorClass = "COLLATION_MISMATCH.EXPLICIT", + condition = "COLLATION_MISMATCH.EXPLICIT", sqlState = "42P21", parameters = Map( "explicitTypes" -> @@ -455,7 +487,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sql(s"SELECT c1 FROM $tableName " + s"WHERE c1 = SUBSTR(COLLATE('a', 'UNICODE'), 0)") }, - errorClass = "COLLATION_MISMATCH.IMPLICIT", + condition = "COLLATION_MISMATCH.IMPLICIT", parameters = Map.empty ) @@ -479,7 +511,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT c1 || c2 FROM $tableName") }, - errorClass = "COLLATION_MISMATCH.IMPLICIT" + condition = "COLLATION_MISMATCH.IMPLICIT" ) @@ -494,7 +526,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT c1 FROM $tableName WHERE c1 = c3") }, - errorClass = "COLLATION_MISMATCH.IMPLICIT" + condition = "COLLATION_MISMATCH.IMPLICIT" ) // different explicit collations are set @@ -506,7 +538,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { |WHERE COLLATE('a', 'UTF8_BINARY') = COLLATE('a', 'UNICODE')""" .stripMargin) }, - errorClass = "COLLATION_MISMATCH.EXPLICIT", + condition = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( "explicitTypes" -> "`string`, `string collate UNICODE`" ) @@ -518,7 +550,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sql(s"SELECT c1 FROM $tableName WHERE c1 IN " + "(COLLATE('a', 'UTF8_BINARY'), COLLATE('b', 'UNICODE'))") }, - errorClass = "COLLATION_MISMATCH.EXPLICIT", + condition = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( "explicitTypes" -> "`string`, `string collate UNICODE`" ) @@ -528,7 +560,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sql(s"SELECT c1 FROM $tableName WHERE COLLATE(c1, 'UNICODE') IN " + "(COLLATE('a', 'UTF8_BINARY'))") }, - errorClass = "COLLATION_MISMATCH.EXPLICIT", + condition = "COLLATION_MISMATCH.EXPLICIT", parameters = Map( "explicitTypes" -> "`string collate UNICODE`, `string`" ) @@ -540,7 +572,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT c1 FROM $tableName WHERE c1 || c3 = 'aa'") }, - errorClass = "COLLATION_MISMATCH.IMPLICIT" + condition = "COLLATION_MISMATCH.IMPLICIT" ) // concat on different implicit collations should succeed, @@ -549,7 +581,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT * FROM $tableName ORDER BY c1 || c3") }, - errorClass = 
"COLLATION_MISMATCH.IMPLICIT" + condition = "COLLATION_MISMATCH.IMPLICIT" ) // concat + in @@ -566,14 +598,14 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT * FROM $tableName WHERE contains(c1||c3, 'a')") }, - errorClass = "COLLATION_MISMATCH.IMPLICIT" + condition = "COLLATION_MISMATCH.IMPLICIT" ) checkError( exception = intercept[AnalysisException] { sql(s"SELECT array('A', 'a' COLLATE UNICODE) == array('b' COLLATE UNICODE_CI)") }, - errorClass = "COLLATION_MISMATCH.IMPLICIT" + condition = "COLLATION_MISMATCH.IMPLICIT" ) checkAnswer(sql("SELECT array_join(array('a', 'b' collate UNICODE), 'c' collate UNICODE_CI)"), @@ -592,7 +624,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql("select map('a' COLLATE UTF8_LCASE, 'b', 'c')") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map("functionName" -> "`map`", "expectedNum" -> "2n (n > 0)", "actualNum" -> "3", "docroot" -> "https://spark.apache.org/docs/latest") ) @@ -602,7 +634,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql("select map('a' COLLATE UTF8_LCASE, 'b', 'c' COLLATE UNICODE, 'c')") }, - errorClass = "COLLATION_MISMATCH.EXPLICIT", + condition = "COLLATION_MISMATCH.EXPLICIT", sqlState = "42P21", parameters = Map( "explicitTypes" -> @@ -722,7 +754,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"CREATE TABLE $newTableName AS SELECT c1 || c2 FROM $tableName") }, - errorClass = "COLLATION_MISMATCH.IMPLICIT") + condition = "COLLATION_MISMATCH.IMPLICIT") } } } @@ -760,7 +792,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } test("disable partition on collated string column") { - spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) def createTable(partitionColumns: String*): Unit = { val tableName = "test_partition_tbl" withTable(tableName) { @@ -784,7 +815,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { createTable(partitionColumns: _*) }, - errorClass = "INVALID_PARTITION_COLUMN_DATA_TYPE", + condition = "INVALID_PARTITION_COLUMN_DATA_TYPE", parameters = Map("type" -> "\"STRING COLLATE UNICODE\"") ); } @@ -821,7 +852,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { |USING $v2Source |""".stripMargin) }, - errorClass = "UNSUPPORTED_EXPRESSION_GENERATED_COLUMN", + condition = "UNSUPPORTED_EXPRESSION_GENERATED_COLUMN", parameters = Map( "fieldName" -> "c2", "expressionStr" -> "SUBSTRING(c1, 0, 1)", @@ -839,7 +870,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { |USING $v2Source |""".stripMargin) }, - errorClass = "UNSUPPORTED_EXPRESSION_GENERATED_COLUMN", + condition = "UNSUPPORTED_EXPRESSION_GENERATED_COLUMN", parameters = Map( "fieldName" -> "c2", "expressionStr" -> "LOWER(c1)", @@ -857,7 +888,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { |USING $v2Source |""".stripMargin) }, - errorClass = "UNSUPPORTED_EXPRESSION_GENERATED_COLUMN", + condition = "UNSUPPORTED_EXPRESSION_GENERATED_COLUMN", parameters = Map( "fieldName" -> "c2", "expressionStr" -> "UCASE(struct1.a)", @@ -875,7 +906,7 @@ class CollationSuite 
extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkError( exception = intercept[ParseException] (sql("SELECT cast(1 as string collate unicode)")), - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map( "typeName" -> toSQLType(StringType("UNICODE"))), context = @@ -885,7 +916,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkError( exception = intercept[ParseException] (sql("SELECT 'A' :: string collate unicode")), - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map( "typeName" -> toSQLType(StringType("UNICODE"))), context = ExpectedContext(fragment = s"'A' :: string collate unicode", start = 7, stop = 35) @@ -898,7 +929,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkError( exception = intercept[ParseException] (sql("SELECT cast(1 as string collate unicode)")), - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map( "typeName" -> toSQLType(StringType("UNICODE"))), context = @@ -950,7 +981,10 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { withTable(table) { sql(s"create table $table (a array) using parquet") sql(s"insert into $table values (array('aaa')), (array('AAA'))") - checkAnswer(sql(s"select distinct a from $table"), Seq(Row(Seq("aaa")))) + val result = sql(s"select distinct a from $table").collect() + assert(result.length === 1) + val data = result.head.getSeq[String](0) + assert(data === Array("aaa") || data === Array("AAA")) } // map doesn't support aggregation withTable(table) { @@ -958,7 +992,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { val query = s"select distinct m from $table" checkError( exception = intercept[ExtendedAnalysisException](sql(query)), - errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", + condition = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", parameters = Map( "colName" -> "`m`", "dataType" -> toSQLType(MapType( @@ -971,7 +1005,10 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { withTable(table) { sql(s"create table $table (s struct) using parquet") sql(s"insert into $table values (named_struct('fld', 'aaa')), (named_struct('fld', 'AAA'))") - checkAnswer(sql(s"select s.fld from $table group by s"), Seq(Row("aaa"))) + val result = sql(s"select s.fld from $table group by s").collect() + assert(result.length === 1) + val data = result.head.getString(0) + assert(data === "aaa" || data === "AAA") } } @@ -1000,7 +1037,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { val ctx = s"$tableLeft.m = $tableRight.m" checkError( exception = intercept[AnalysisException](sql(query)), - errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", parameters = Map( "functionName" -> "`=`", "dataType" -> toSQLType(MapType( @@ -1127,7 +1164,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { val ctx = "m" checkError( exception = intercept[AnalysisException](sql(query)), - errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", parameters = Map( "functionName" -> "`sortorder`", "dataType" -> s"\"MAP\"", @@ -1180,7 +1217,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { val query = s"select $ctx" checkError( exception = 
intercept[AnalysisException](sql(query)), - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"map(collate(aaa, utf8_lcase), 1, collate(AAA, utf8_lcase), 2)[AaA]\"", "paramIndex" -> "second", @@ -1621,4 +1658,77 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } } + + test("TVF collations()") { + assert(sql("SELECT * FROM collations()").collect().length >= 562) + + // verify that the output ordering is as expected (UTF8_BINARY, UTF8_LCASE, etc.) + val df = sql("SELECT * FROM collations() limit 10") + checkAnswer(df, + Seq(Row("SYSTEM", "BUILTIN", "UTF8_BINARY", null, null, + "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", null), + Row("SYSTEM", "BUILTIN", "UTF8_LCASE", null, null, + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", null), + Row("SYSTEM", "BUILTIN", "UNICODE", "", "", + "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "UNICODE_AI", "", "", + "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "UNICODE_CI", "", "", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "UNICODE_CI_AI", "", "", + "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "af", "Afrikaans", "", + "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "af_AI", "Afrikaans", "", + "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "af_CI", "Afrikaans", "", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "af_CI_AI", "Afrikaans", "", + "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"))) + + checkAnswer(sql("SELECT * FROM collations() WHERE NAME LIKE '%UTF8_BINARY%'"), + Row("SYSTEM", "BUILTIN", "UTF8_BINARY", null, null, + "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", null)) + + checkAnswer(sql("SELECT * FROM collations() WHERE NAME LIKE '%zh_Hant_HKG%'"), + Seq(Row("SYSTEM", "BUILTIN", "zh_Hant_HKG", "Chinese", "Hong Kong SAR China", + "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "zh_Hant_HKG_AI", "Chinese", "Hong Kong SAR China", + "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "zh_Hant_HKG_CI", "Chinese", "Hong Kong SAR China", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "zh_Hant_HKG_CI_AI", "Chinese", "Hong Kong SAR China", + "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"))) + + checkAnswer(sql("SELECT * FROM collations() WHERE COUNTRY = 'Singapore'"), + Seq(Row("SYSTEM", "BUILTIN", "zh_Hans_SGP", "Chinese", "Singapore", + "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "zh_Hans_SGP_AI", "Chinese", "Singapore", + "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "zh_Hans_SGP_CI", "Chinese", "Singapore", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "zh_Hans_SGP_CI_AI", "Chinese", "Singapore", + "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"))) + + checkAnswer(sql("SELECT * FROM collations() WHERE LANGUAGE = 'English' " + + "and COUNTRY = 'United States'"), + Seq(Row("SYSTEM", "BUILTIN", "en_USA", "English", "United States", + "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", 
"BUILTIN", "en_USA_AI", "English", "United States", + "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "en_USA_CI", "English", "United States", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "en_USA_CI_AI", "English", "United States", + "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"))) + + checkAnswer(sql("SELECT NAME, LANGUAGE, ACCENT_SENSITIVITY, CASE_SENSITIVITY " + + "FROM collations() WHERE COUNTRY = 'United States'"), + Seq(Row("en_USA", "English", "ACCENT_SENSITIVE", "CASE_SENSITIVE"), + Row("en_USA_AI", "English", "ACCENT_INSENSITIVE", "CASE_SENSITIVE"), + Row("en_USA_CI", "English", "ACCENT_SENSITIVE", "CASE_INSENSITIVE"), + Row("en_USA_CI_AI", "English", "ACCENT_INSENSITIVE", "CASE_INSENSITIVE"))) + + checkAnswer(sql("SELECT NAME FROM collations() WHERE ICU_VERSION is null"), + Seq(Row("UTF8_BINARY"), Row("UTF8_LCASE"))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 68a7a4b8b2412..9cd35e527df57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -455,7 +455,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df2.filter($"a".isin($"b")) }, - errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", parameters = Map( "functionName" -> "`in`", "dataType" -> "[\"INT\", \"ARRAY\"]", @@ -523,7 +523,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df2.filter($"a".isInCollection(Seq($"b"))) }, - errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", parameters = Map( "functionName" -> "`in`", "dataType" -> "[\"INT\", \"ARRAY\"]", @@ -734,7 +734,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"SELECT *, $f() FROM tab1 JOIN tab2 ON tab1.id = tab2.id") }, - errorClass = "MULTI_SOURCES_UNSUPPORTED_FOR_EXPRESSION", + condition = "MULTI_SOURCES_UNSUPPORTED_FOR_EXPRESSION", parameters = Map("expr" -> s""""$f()""""), context = ExpectedContext( fragment = s"$f()", @@ -753,7 +753,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(stmt) }, - errorClass = "MULTI_SOURCES_UNSUPPORTED_FOR_EXPRESSION", + condition = "MULTI_SOURCES_UNSUPPORTED_FOR_EXPRESSION", parameters = Map("expr" -> """"input_file_name()""""), context = ExpectedContext( fragment = s"input_file_name()", @@ -1055,7 +1055,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { testData.withColumn("key", $"key".withField("a", lit(2))) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"update_fields(key, WithField(2))\"", "paramIndex" -> "first", @@ -1087,14 +1087,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { structLevel2.withColumn("a", $"a".withField("x.b", lit(2))) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`x`", "fields" -> "`a`")) 
checkError( exception = intercept[AnalysisException] { structLevel3.withColumn("a", $"a".withField("a.x.b", lit(2))) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`x`", "fields" -> "`a`")) } @@ -1103,7 +1103,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { structLevel1.withColumn("a", $"a".withField("b.a", lit(2))) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"update_fields(a.b, WithField(2))\"", "paramIndex" -> "first", @@ -1129,7 +1129,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { structLevel2.withColumn("a", $"a".withField("a.b", lit(2))) }, - errorClass = "AMBIGUOUS_REFERENCE_TO_FIELDS", + condition = "AMBIGUOUS_REFERENCE_TO_FIELDS", sqlState = "42000", parameters = Map("field" -> "`a`", "count" -> "2") ) @@ -1532,7 +1532,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.withColumn("a", $"a".withField("a.b.e.f", lit(2))) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`a`", "fields" -> "`a`.`b`")) } @@ -1644,14 +1644,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { mixedCaseStructLevel2.withColumn("a", $"a".withField("A.a", lit(2))) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`A`", "fields" -> "`a`, `B`")) checkError( exception = intercept[AnalysisException] { mixedCaseStructLevel2.withColumn("a", $"a".withField("b.a", lit(2))) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`b`", "fields" -> "`a`, `B`")) } } @@ -1687,7 +1687,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col") .select($"struct_col".withField("a.c", lit(3))) }, - errorClass = "AMBIGUOUS_REFERENCE_TO_FIELDS", + condition = "AMBIGUOUS_REFERENCE_TO_FIELDS", sqlState = "42000", parameters = Map("field" -> "`a`", "count" -> "2") ) @@ -1854,7 +1854,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { testData.withColumn("key", $"key".dropFields("a")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"update_fields(key, dropfield())\"", "paramIndex" -> "first", @@ -1878,14 +1878,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { structLevel2.withColumn("a", $"a".dropFields("x.b")) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`x`", "fields" -> "`a`")) checkError( exception = intercept[AnalysisException] { structLevel3.withColumn("a", $"a".dropFields("a.x.b")) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`x`", "fields" -> "`a`")) } @@ -1894,7 +1894,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { structLevel1.withColumn("a", $"a".dropFields("b.a")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = 
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"update_fields(a.b, dropfield())\"", "paramIndex" -> "first", @@ -1920,7 +1920,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { structLevel2.withColumn("a", $"a".dropFields("a.b")) }, - errorClass = "AMBIGUOUS_REFERENCE_TO_FIELDS", + condition = "AMBIGUOUS_REFERENCE_TO_FIELDS", sqlState = "42000", parameters = Map("field" -> "`a`", "count" -> "2") ) @@ -1968,7 +1968,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { structLevel1.withColumn("a", $"a".dropFields("a", "b", "c")) }, - errorClass = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", + condition = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", parameters = Map("sqlExpr" -> "\"update_fields(a, dropfield(), dropfield(), dropfield())\""), context = ExpectedContext( fragment = "dropFields", @@ -2158,14 +2158,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { mixedCaseStructLevel2.withColumn("a", $"a".dropFields("A.a")) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`A`", "fields" -> "`a`, `B`")) checkError( exception = intercept[AnalysisException] { mixedCaseStructLevel2.withColumn("a", $"a".dropFields("b.a")) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`b`", "fields" -> "`a`, `B`")) } } @@ -2243,7 +2243,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { sql("SELECT named_struct('a', 1, 'b', 2) struct_col") .select($"struct_col".dropFields("a", "b")) }, - errorClass = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", + condition = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", parameters = Map("sqlExpr" -> "\"update_fields(struct_col, dropfield(), dropfield())\""), context = ExpectedContext( fragment = "dropFields", @@ -2270,7 +2270,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col") .select($"struct_col".dropFields("a.c")) }, - errorClass = "AMBIGUOUS_REFERENCE_TO_FIELDS", + condition = "AMBIGUOUS_REFERENCE_TO_FIELDS", sqlState = "42000", parameters = Map("field" -> "`a`", "count" -> "2") ) @@ -2420,7 +2420,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { structLevel1.select($"a".withField("d", lit(4)).withField("e", $"a.d" + 1).as("a")) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`d`", "fields" -> "`a`, `b`, `c`"), context = ExpectedContext( fragment = "$", @@ -2476,7 +2476,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { .select($"a".dropFields("c").as("a")) .select($"a".withField("z", $"a.c")).as("a") }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`c`", "fields" -> "`a`, `b`"), context = ExpectedContext( fragment = "$", @@ -2575,7 +2575,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { booleanDf.select(assert_true($"cond", lit(null.asInstanceOf[String]))).collect() }, - errorClass = "USER_RAISED_EXCEPTION", + condition = "USER_RAISED_EXCEPTION", parameters = Map("errorMessage" -> "null")) val nullDf = Seq(("first row", None), ("second row", Some(true))).toDF("n", "cond") @@ 
-2587,7 +2587,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { nullDf.select(assert_true($"cond", $"n")).collect() }, - errorClass = "USER_RAISED_EXCEPTION", + condition = "USER_RAISED_EXCEPTION", parameters = Map("errorMessage" -> "first row")) // assert_true(condition) @@ -2607,14 +2607,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { strDf.select(raise_error(lit(null.asInstanceOf[String]))).collect() }, - errorClass = "USER_RAISED_EXCEPTION", + condition = "USER_RAISED_EXCEPTION", parameters = Map("errorMessage" -> "null")) checkError( exception = intercept[SparkRuntimeException] { strDf.select(raise_error($"a")).collect() }, - errorClass = "USER_RAISED_EXCEPTION", + condition = "USER_RAISED_EXCEPTION", parameters = Map("errorMessage" -> "hello")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index bb3c00d238ca6..6589282fd3a51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -51,7 +51,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq("1").toDS().select(from_csv($"value", lit("ARRAY"), Map[String, String]().asJava)) }, - errorClass = "INVALID_SCHEMA.NON_STRUCT_TYPE", + condition = "INVALID_SCHEMA.NON_STRUCT_TYPE", parameters = Map( "inputSchema" -> "\"ARRAY\"", "dataType" -> "\"ARRAY\"" @@ -63,7 +63,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq("1").toDF("csv").selectExpr(s"from_csv(csv, 'ARRAY')") }, - errorClass = "INVALID_SCHEMA.NON_STRUCT_TYPE", + condition = "INVALID_SCHEMA.NON_STRUCT_TYPE", parameters = Map( "inputSchema" -> "\"ARRAY\"", "dataType" -> "\"ARRAY\"" @@ -109,7 +109,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkUpgradeException] { df2.collect() }, - errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.PARSE_DATETIME_BY_NEW_PARSER", + condition = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.PARSE_DATETIME_BY_NEW_PARSER", parameters = Map( "datetime" -> "'2013-111-11 12:13:14'", "config" -> "\"spark.sql.legacy.timeParserPolicy\"")) @@ -184,7 +184,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkUnsupportedOperationException] { df.select(from_csv(to_csv($"value"), schema, options)).collect() }, - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map("typeName" -> toSQLType(valueType)) ) } @@ -343,7 +343,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkException] { df.select(from_csv($"value", schema, Map("mode" -> "FAILFAST"))).collect() }, - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map("badRecord" -> "[null,null,\"]", "failFastMode" -> "FAILFAST") ) @@ -351,7 +351,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(from_csv($"value", schema, Map("mode" -> "DROPMALFORMED"))).collect() }, - errorClass = "_LEGACY_ERROR_TEMP_1099", + condition = "_LEGACY_ERROR_TEMP_1099", parameters = Map( "funcName" -> "from_csv", "mode" 
-> "DROPMALFORMED", @@ -433,7 +433,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { Seq(("1", "i int")).toDF("csv", "schema") .select(from_csv($"csv", $"schema", options)).collect() }, - errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", + condition = "INVALID_SCHEMA.NON_STRING_LITERAL", parameters = Map("inputSchema" -> "\"schema\""), context = ExpectedContext(fragment = "from_csv", getCurrentClassCallSitePattern) ) @@ -442,7 +442,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq("1").toDF("csv").select(from_csv($"csv", lit(1), options)).collect() }, - errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", + condition = "INVALID_SCHEMA.NON_STRING_LITERAL", parameters = Map("inputSchema" -> "\"1\""), context = ExpectedContext(fragment = "from_csv", getCurrentClassCallSitePattern) ) @@ -493,14 +493,14 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkException] { df.selectExpr("parsed.a").collect() }, - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map("badRecord" -> "[1,null]", "failFastMode" -> "FAILFAST")) checkError( exception = intercept[SparkException] { df.selectExpr("parsed.b").collect() }, - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map("badRecord" -> "[1,null]", "failFastMode" -> "FAILFAST")) } } @@ -753,7 +753,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(to_csv($"value")).collect() }, - errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE", parameters = Map( "functionName" -> "`to_csv`", "dataType" -> "\"STRUCT\"", @@ -765,7 +765,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkUnsupportedOperationException] { df.select(from_csv(lit("data"), valueSchema, Map.empty[String, String])).collect() }, - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map("typeName" -> "\"VARIANT\"") ) } @@ -776,7 +776,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(to_csv($"value")).collect() }, - errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE", parameters = Map( "functionName" -> "`to_csv`", "dataType" -> "\"INT\"", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 66b1883b91d5f..e80c3b23a7db3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -645,7 +645,7 @@ class DataFrameAggregateSuite extends QueryTest } checkError( exception = error, - errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE", parameters = Map( "functionName" -> "`collect_set`", "dataType" -> "\"MAP\"", @@ -725,7 +725,7 @@ class DataFrameAggregateSuite extends QueryTest exception = intercept[AnalysisException] { testData.groupBy(sum($"key")).count() }, - errorClass = "GROUP_BY_AGGREGATE", + condition = "GROUP_BY_AGGREGATE", parameters = 
Map("sqlExpr" -> "sum(key)"), context = ExpectedContext(fragment = "sum", callSitePattern = getCurrentClassCallSitePattern) ) @@ -985,7 +985,7 @@ class DataFrameAggregateSuite extends QueryTest } checkError( exception = error, - errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", sqlState = None, parameters = Map( "functionName" -> "`max_by`", @@ -1055,7 +1055,7 @@ class DataFrameAggregateSuite extends QueryTest } checkError( exception = error, - errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", sqlState = None, parameters = Map( "functionName" -> "`min_by`", @@ -1186,7 +1186,7 @@ class DataFrameAggregateSuite extends QueryTest exception = intercept[AnalysisException] { sql("SELECT COUNT_IF(x) FROM tempView") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"count_if(x)\"", @@ -1350,7 +1350,7 @@ class DataFrameAggregateSuite extends QueryTest exception = intercept[AnalysisException] { Seq(Tuple1(Seq(1))).toDF("col").groupBy(struct($"col.a")).count() }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"col[a]\"", "paramIndex" -> "second", @@ -1924,7 +1924,7 @@ class DataFrameAggregateSuite extends QueryTest ) .collect() }, - errorClass = "HLL_INVALID_LG_K", + condition = "HLL_INVALID_LG_K", parameters = Map( "function" -> "`hll_sketch_agg`", "min" -> "4", @@ -1940,7 +1940,7 @@ class DataFrameAggregateSuite extends QueryTest ) .collect() }, - errorClass = "HLL_INVALID_LG_K", + condition = "HLL_INVALID_LG_K", parameters = Map( "function" -> "`hll_sketch_agg`", "min" -> "4", @@ -1963,7 +1963,7 @@ class DataFrameAggregateSuite extends QueryTest .withColumn("union", hll_union("hllsketch_left", "hllsketch_right")) .collect() }, - errorClass = "HLL_UNION_DIFFERENT_LG_K", + condition = "HLL_UNION_DIFFERENT_LG_K", parameters = Map( "left" -> "12", "right" -> "20", @@ -1986,7 +1986,7 @@ class DataFrameAggregateSuite extends QueryTest ) .collect() }, - errorClass = "HLL_UNION_DIFFERENT_LG_K", + condition = "HLL_UNION_DIFFERENT_LG_K", parameters = Map( "left" -> "12", "right" -> "20", @@ -2007,7 +2007,7 @@ class DataFrameAggregateSuite extends QueryTest |""".stripMargin) checkAnswer(res, Nil) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"hll_sketch_agg(value, text)\"", "paramIndex" -> "second", @@ -2036,7 +2036,7 @@ class DataFrameAggregateSuite extends QueryTest |""".stripMargin) checkAnswer(res, Nil) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"hll_union_agg(sketch, Hll_4)\"", "paramIndex" -> "second", @@ -2078,7 +2078,7 @@ class DataFrameAggregateSuite extends QueryTest | cte1 join cte2 on cte1.id = cte2.id |""".stripMargin).collect() }, - errorClass = "HLL_UNION_DIFFERENT_LG_K", + condition = "HLL_UNION_DIFFERENT_LG_K", parameters = Map( "left" -> "12", "right" -> "20", @@ -2114,7 +2114,7 @@ class DataFrameAggregateSuite extends QueryTest |group by 1 |""".stripMargin).collect() }, - errorClass = "HLL_UNION_DIFFERENT_LG_K", + condition = "HLL_UNION_DIFFERENT_LG_K", parameters = Map( "left" -> "12", "right" -> "20", @@ -2490,6 +2490,27 @@ class 
DataFrameAggregateSuite extends QueryTest }) } } + + test("SPARK-49261: Literals in grouping expressions shouldn't result in unresolved aggregation") { + val data = Seq((1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4)).toDF("a", "b", "c") + withTempView("v1") { + data.createOrReplaceTempView("v1") + val df = + sql("""SELECT + | ROUND(SUM(b), 6) AS sum1, + | COUNT(DISTINCT a) AS count1, + | COUNT(DISTINCT c) AS count2 + |FROM ( + | SELECT + | 6 AS gb, + | * + | FROM v1 + |) + |GROUP BY a, gb + |""".stripMargin) + checkAnswer(df, Row(1.001d, 1, 1) :: Row(6.002d, 1, 1) :: Nil) + } + } } case class B(c: Option[Double]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala index a03f083123558..5f0ae918524e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala @@ -108,7 +108,7 @@ class DataFrameAsOfJoinSuite extends QueryTest joinType = "inner", tolerance = df1.col("b"), allowExactMatches = true, direction = "backward") }, - errorClass = "AS_OF_JOIN.TOLERANCE_IS_UNFOLDABLE", + condition = "AS_OF_JOIN.TOLERANCE_IS_UNFOLDABLE", parameters = Map.empty) } @@ -120,7 +120,7 @@ class DataFrameAsOfJoinSuite extends QueryTest joinType = "inner", tolerance = lit(-1), allowExactMatches = true, direction = "backward") }, - errorClass = "AS_OF_JOIN.TOLERANCE_IS_NON_NEGATIVE", + condition = "AS_OF_JOIN.TOLERANCE_IS_NON_NEGATIVE", parameters = Map.empty) } @@ -133,7 +133,7 @@ class DataFrameAsOfJoinSuite extends QueryTest joinType = "inner", tolerance = lit(-1), allowExactMatches = true, direction = direction) }, - errorClass = "AS_OF_JOIN.UNSUPPORTED_DIRECTION", + condition = "AS_OF_JOIN.UNSUPPORTED_DIRECTION", sqlState = "42604", parameters = Map( "direction" -> direction, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 739cef035c38c..0842b92e5d53c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -import java.lang.reflect.Modifier import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import scala.reflect.runtime.universe.runtimeMirror import scala.util.Random import org.apache.spark.{QueryContextType, SPARK_DOC_ROOT, SparkException, SparkRuntimeException} @@ -45,7 +45,7 @@ import org.apache.spark.tags.ExtendedSQLTest class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { import testImplicits._ - test("DataFrame function and SQL functon parity") { + test("DataFrame function and SQL function parity") { // This test compares the available list of DataFrame functions in // org.apache.spark.sql.functions with the SQL function registry. 
This attempts to verify that // the DataFrame functions are a subset of the functions in the SQL function registry (subject @@ -82,7 +82,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "bucket", "days", "hours", "months", "years", // Datasource v2 partition transformations "product", // Discussed in https://github.com/apache/spark/pull/30745 "unwrap_udt", - "collect_top_k", "timestamp_add", "timestamp_diff" ) @@ -92,10 +91,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { val word_pattern = """\w*""" // Set of DataFrame functions in org.apache.spark.sql.functions - val dataFrameFunctions = functions.getClass - .getDeclaredMethods - .filter(m => Modifier.isPublic(m.getModifiers)) - .map(_.getName) + val dataFrameFunctions = runtimeMirror(getClass.getClassLoader) + .reflect(functions) + .symbol + .typeSignature + .decls + .filter(s => s.isMethod && s.isPublic) + .map(_.name.toString) .toSet .filter(_.matches(word_pattern)) .diff(excludedDataFrameFunctions) @@ -166,7 +168,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df4.select(map_from_arrays($"k", $"v")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"map_from_arrays(k, v)\"", "paramIndex" -> "first", @@ -185,7 +187,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { df5.select(map_from_arrays($"k", $"v")).collect() }, - errorClass = "NULL_MAP_KEY", + condition = "NULL_MAP_KEY", parameters = Map.empty ) @@ -344,7 +346,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { var expr = nullifzero(map(lit(1), lit("a"))) checkError( intercept[AnalysisException](df.select(expr)), - errorClass = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES", parameters = Map( "left" -> "\"MAP\"", "right" -> "\"INT\"", @@ -360,7 +362,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { expr = nullifzero(array(lit(1), lit(2))) checkError( intercept[AnalysisException](df.select(expr)), - errorClass = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES", parameters = Map( "left" -> "\"ARRAY\"", "right" -> "\"INT\"", @@ -376,7 +378,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { expr = nullifzero(Literal.create(20201231, DateType)) checkError( intercept[AnalysisException](df.select(expr)), - errorClass = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES", parameters = Map( "left" -> "\"DATE\"", "right" -> "\"INT\"", @@ -422,7 +424,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { var expr = zeroifnull(map(lit(1), lit("a"))) checkError( intercept[AnalysisException](df.select(expr)), - errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", parameters = Map( "functionName" -> "`coalesce`", "dataType" -> "(\"MAP\" or \"INT\")", @@ -438,7 +440,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { expr = zeroifnull(array(lit(1), lit(2))) checkError( intercept[AnalysisException](df.select(expr)), - errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", parameters = Map( "functionName" -> "`coalesce`", 
"dataType" -> "(\"ARRAY\" or \"INT\")", @@ -454,7 +456,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { expr = zeroifnull(Literal.create(20201231, DateType)) checkError( intercept[AnalysisException](df.select(expr)), - errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", parameters = Map( "functionName" -> "`coalesce`", "dataType" -> "(\"DATE\" or \"INT\")", @@ -886,7 +888,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(array_sort(col("a"), (x, y) => x - y)) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> """"array_sort\(a, lambdafunction\(`-`\(x_\d+, y_\d+\), x_\d+, y_\d+\)\)"""", "paramIndex" -> "first", @@ -953,7 +955,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = error, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"sort_array(a, true)\"", "paramIndex" -> "first", @@ -964,6 +966,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { queryContext = Array(ExpectedContext("", "", 0, 12, "sort_array(a)")) ) + val df4 = Seq((Array[Int](2, 1, 3), true), (Array.empty[Int], false)).toDF("a", "b") + checkError( + exception = intercept[AnalysisException] { + df4.selectExpr("sort_array(a, b)").collect() + }, + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + sqlState = "42K09", + parameters = Map( + "inputName" -> "`ascendingOrder`", + "inputType" -> "\"BOOLEAN\"", + "inputExpr" -> "\"b\"", + "sqlExpr" -> "\"sort_array(a, b)\""), + context = ExpectedContext(fragment = "sort_array(a, b)", start = 0, stop = 15) + ) + + checkError( + exception = intercept[AnalysisException] { + df4.selectExpr("sort_array(a, 'A')").collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> "\"sort_array(a, A)\"", + "paramIndex" -> "second", + "inputSql" -> "\"A\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"BOOLEAN\""), + context = ExpectedContext(fragment = "sort_array(a, 'A')", start = 0, stop = 17) + ) + checkAnswer( df.select(array_sort($"a"), array_sort($"b")), Seq( @@ -989,7 +1021,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df3.selectExpr("array_sort(a)").collect() }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"array_sort(a, lambdafunction((IF(((left IS NULL) AND (right IS NULL)), 0, (IF((left IS NULL), 1, (IF((right IS NULL), -1, (IF((left < right), -1, (IF((left > right), 1, 0)))))))))), left, right))\"", @@ -1302,7 +1334,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("map_contains_key(a, null)").collect() }, - errorClass = "DATATYPE_MISMATCH.NULL_TYPE", + condition = "DATATYPE_MISMATCH.NULL_TYPE", parameters = Map( "sqlExpr" -> "\"map_contains_key(a, NULL)\"", "functionName" -> "`map_contains_key`"), @@ -1379,7 +1411,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df2.selectExpr("map_concat(map1, map2)").collect() }, - errorClass 
= "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", sqlState = None, parameters = Map( "sqlExpr" -> "\"map_concat(map1, map2)\"", @@ -1395,7 +1427,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df2.select(map_concat($"map1", $"map2")).collect() }, - errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", sqlState = None, parameters = Map( "sqlExpr" -> "\"map_concat(map1, map2)\"", @@ -1411,7 +1443,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df2.selectExpr("map_concat(map1, 12)").collect() }, - errorClass = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES", sqlState = None, parameters = Map( "sqlExpr" -> "\"map_concat(map1, 12)\"", @@ -1427,7 +1459,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df2.select(map_concat($"map1", lit(12))).collect() }, - errorClass = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES", sqlState = None, parameters = Map( "sqlExpr" -> "\"map_concat(map1, 12)\"", @@ -1498,7 +1530,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { wrongTypeDF.select(map_from_entries($"a")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"map_from_entries(a)\"", "paramIndex" -> "first", @@ -1542,7 +1574,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(array_contains(df("a"), null)) }, - errorClass = "DATATYPE_MISMATCH.NULL_TYPE", + condition = "DATATYPE_MISMATCH.NULL_TYPE", parameters = Map( "sqlExpr" -> "\"array_contains(a, NULL)\"", "functionName" -> "`array_contains`" @@ -1556,7 +1588,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("array_contains(a, null)") }, - errorClass = "DATATYPE_MISMATCH.NULL_TYPE", + condition = "DATATYPE_MISMATCH.NULL_TYPE", parameters = Map( "sqlExpr" -> "\"array_contains(a, NULL)\"", "functionName" -> "`array_contains`" @@ -1567,7 +1599,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("array_contains(null, 1)") }, - errorClass = "DATATYPE_MISMATCH.NULL_TYPE", + condition = "DATATYPE_MISMATCH.NULL_TYPE", parameters = Map( "sqlExpr" -> "\"array_contains(NULL, 1)\"", "functionName" -> "`array_contains`" @@ -1623,7 +1655,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { OneRowRelation().selectExpr("array_contains(array(1), 'foo')") }, - errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_contains(array(1), foo)\"", "functionName" -> "`array_contains`", @@ -1639,7 +1671,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { OneRowRelation().selectExpr("array_contains('a string', 'foo')") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = 
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"array_contains(a string, foo)\"", "paramIndex" -> "first", @@ -1688,7 +1720,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("select arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))") }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"arrays_overlap(array(1, 2, 3), array(a, b, c))\"", "functionName" -> "`arrays_overlap`", @@ -1704,7 +1736,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("select arrays_overlap(null, null)") }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"arrays_overlap(NULL, NULL)\"", "functionName" -> "`arrays_overlap`", @@ -1719,7 +1751,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("select arrays_overlap(map(1, 2), map(3, 4))") }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"arrays_overlap(map(1, 2), map(3, 4))\"", "functionName" -> "`arrays_overlap`", @@ -1794,7 +1826,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { idf.selectExpr("array_join(x, 1)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"array_join(x, 1)\"", "paramIndex" -> "second", @@ -1808,7 +1840,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { idf.selectExpr("array_join(x, ', ', 1)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"array_join(x, , , 1)\"", "paramIndex" -> "third", @@ -1924,7 +1956,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq((true, false)).toDF().selectExpr("sequence(_1, _2)") }, - errorClass = "DATATYPE_MISMATCH.SEQUENCE_WRONG_INPUT_TYPES", + condition = "DATATYPE_MISMATCH.SEQUENCE_WRONG_INPUT_TYPES", parameters = Map( "sqlExpr" -> "\"sequence(_1, _2)\"", "functionName" -> "`sequence`", @@ -1938,7 +1970,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq((true, false, 42)).toDF().selectExpr("sequence(_1, _2, _3)") }, - errorClass = "DATATYPE_MISMATCH.SEQUENCE_WRONG_INPUT_TYPES", + condition = "DATATYPE_MISMATCH.SEQUENCE_WRONG_INPUT_TYPES", parameters = Map( "sqlExpr" -> "\"sequence(_1, _2, _3)\"", "functionName" -> "`sequence`", @@ -1952,7 +1984,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq((1, 2, 0.5)).toDF().selectExpr("sequence(_1, _2, _3)") }, - errorClass = "DATATYPE_MISMATCH.SEQUENCE_WRONG_INPUT_TYPES", + condition = "DATATYPE_MISMATCH.SEQUENCE_WRONG_INPUT_TYPES", parameters = Map( "sqlExpr" -> "\"sequence(_1, _2, _3)\"", "functionName" -> "`sequence`", @@ -2068,7 +2100,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = 
intercept[AnalysisException] { sql("select reverse(struct(1, 'a'))") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"reverse(struct(1, a))\"", "paramIndex" -> "first", @@ -2083,7 +2115,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("select reverse(map(1, 'a'))") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"reverse(map(1, a))\"", "paramIndex" -> "first", @@ -2169,7 +2201,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq((null, "a")).toDF().selectExpr("array_position(_1, _2)") }, - errorClass = "DATATYPE_MISMATCH.NULL_TYPE", + condition = "DATATYPE_MISMATCH.NULL_TYPE", parameters = Map( "sqlExpr" -> "\"array_position(_1, _2)\"", "functionName" -> "`array_position`" @@ -2181,7 +2213,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq(("a string element", null)).toDF().selectExpr("array_position(_1, _2)") }, - errorClass = "DATATYPE_MISMATCH.NULL_TYPE", + condition = "DATATYPE_MISMATCH.NULL_TYPE", parameters = Map( "sqlExpr" -> "\"array_position(_1, _2)\"", "functionName" -> "`array_position`" @@ -2193,7 +2225,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"array_position(_1, _2)\"", "paramIndex" -> "first", @@ -2208,7 +2240,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { OneRowRelation().selectExpr("array_position(array(1), '1')") }, - errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_position(array(1), 1)\"", "functionName" -> "`array_position`", @@ -2281,7 +2313,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"element_at(_1, _2)\"", "paramIndex" -> "first", @@ -2311,7 +2343,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { OneRowRelation().selectExpr("element_at(array('a', 'b'), 1L)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"element_at(array(a, b), 1)\"", "paramIndex" -> "second", @@ -2358,7 +2390,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), '1')") }, - errorClass = "DATATYPE_MISMATCH.MAP_FUNCTION_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.MAP_FUNCTION_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"element_at(map(1, a, 2, b), 1)\"", "functionName" -> "`element_at`", @@ -2440,7 
+2472,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df6.select(array_union($"a", $"b")) }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_union(a, b)\"", "functionName" -> "`array_union`", @@ -2456,7 +2488,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df6.selectExpr("array_union(a, b)") }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_union(a, b)\"", "functionName" -> "`array_union`", @@ -2475,7 +2507,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df7.select(array_union($"a", $"b")) }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_union(a, b)\"", "functionName" -> "`array_union`", @@ -2489,7 +2521,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df7.selectExpr("array_union(a, b)") }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_union(a, b)\"", "functionName" -> "`array_union`", @@ -2508,7 +2540,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df8.select(array_union($"a", $"b")) }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_union(a, b)\"", "functionName" -> "`array_union`", @@ -2522,7 +2554,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df8.selectExpr("array_union(a, b)") }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_union(a, b)\"", "functionName" -> "`array_union`", @@ -2609,7 +2641,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("concat(i1, i2, null)") }, - errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"concat(i1, i2, NULL)\"", "functionName" -> "`concat`", @@ -2622,7 +2654,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("concat(i1, array(i1, i2))") }, - errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"concat(i1, array(i1, i2))\"", "functionName" -> "`concat`", @@ -2635,7 +2667,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("concat(map(1, 2), map(3, 4))") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"concat(map(1, 2), map(3, 4))\"", "paramIndex" -> "first", @@ -2746,7 +2778,7 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSparkSession { exception = intercept[AnalysisException] { oneRowDF.select(flatten($"arr")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"flatten(arr)\"", "paramIndex" -> "first", @@ -2761,7 +2793,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { oneRowDF.select(flatten($"i")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"flatten(i)\"", "paramIndex" -> "first", @@ -2776,7 +2808,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { oneRowDF.select(flatten($"s")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"flatten(s)\"", "paramIndex" -> "first", @@ -2791,7 +2823,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { oneRowDF.selectExpr("flatten(null)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"flatten(NULL)\"", "paramIndex" -> "first", @@ -2887,7 +2919,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { invalidTypeDF.select(array_repeat($"a", $"b")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"array_repeat(a, b)\"", "paramIndex" -> "second", @@ -2902,7 +2934,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { invalidTypeDF.select(array_repeat($"a", lit("1"))) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"array_repeat(a, 1)\"", "paramIndex" -> "second", @@ -2917,7 +2949,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { invalidTypeDF.selectExpr("array_repeat(a, 1.0)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"array_repeat(a, 1.0)\"", "paramIndex" -> "second", @@ -2968,7 +3000,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq(("a string element", "a")).toDF().selectExpr("array_prepend(_1, _2)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "paramIndex" -> "first", "sqlExpr" -> "\"array_prepend(_1, _2)\"", @@ -2980,7 +3012,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { OneRowRelation().selectExpr("array_prepend(array(1, 2), '1')") }, - errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_prepend(array(1, 2), 1)\"", "functionName" -> "`array_prepend`", @@ -3084,7 +3116,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq(("a 
string element", "a")).toDF().selectExpr("array_remove(_1, _2)") }, - errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_remove(_1, _2)\"", "functionName" -> "`array_remove`", @@ -3099,7 +3131,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { OneRowRelation().selectExpr("array_remove(array(1, 2), '1')") }, - errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_remove(array(1, 2), 1)\"", "functionName" -> "`array_remove`", @@ -3232,7 +3264,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df6.select(array_except($"a", $"b")) }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_except(a, b)\"", "functionName" -> "`array_except`", @@ -3247,7 +3279,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df6.selectExpr("array_except(a, b)") }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_except(a, b)\"", "functionName" -> "`array_except`", @@ -3262,7 +3294,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df7.select(array_except($"a", $"b")) }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_except(a, b)\"", "functionName" -> "`array_except`", @@ -3277,7 +3309,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df7.selectExpr("array_except(a, b)") }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_except(a, b)\"", "functionName" -> "`array_except`", @@ -3292,7 +3324,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df8.select(array_except($"a", $"b")) }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_except(a, b)\"", "functionName" -> "`array_except`", @@ -3307,7 +3339,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df8.selectExpr("array_except(a, b)") }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_except(a, b)\"", "functionName" -> "`array_except`", @@ -3322,7 +3354,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df9.select(array_except($"a", $"b")) }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_except(a, b)\"", "functionName" -> "`array_except`", @@ -3337,7 +3369,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { 
exception = intercept[AnalysisException] { df9.selectExpr("array_except(a, b)") }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_except(a, b)\"", "functionName" -> "`array_except`", @@ -3393,7 +3425,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df6.select(array_intersect($"a", $"b")) }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_intersect(a, b)\"", "functionName" -> "`array_intersect`", @@ -3408,7 +3440,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df6.selectExpr("array_intersect(a, b)") }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_intersect(a, b)\"", "functionName" -> "`array_intersect`", @@ -3424,7 +3456,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df7.select(array_intersect($"a", $"b")) }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_intersect(a, b)\"", "functionName" -> "`array_intersect`", @@ -3439,7 +3471,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df7.selectExpr("array_intersect(a, b)") }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_intersect(a, b)\"", "functionName" -> "`array_intersect`", @@ -3455,7 +3487,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df8.select(array_intersect($"a", $"b")) }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_intersect(a, b)\"", "functionName" -> "`array_intersect`", @@ -3472,7 +3504,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df8.selectExpr("array_intersect(a, b)") }, - errorClass = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array_intersect(a, b)\"", "functionName" -> "`array_intersect`", @@ -3506,7 +3538,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { df5.selectExpr("array_insert(a, b, c)").show() }, - errorClass = "INVALID_INDEX_OF_ZERO", + condition = "INVALID_INDEX_OF_ZERO", parameters = Map.empty, context = ExpectedContext( fragment = "array_insert(a, b, c)", @@ -3748,7 +3780,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("transform(s, (x, y, z) -> x + y + z)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map("expectedNumArgs" -> "3", "actualNumArgs" -> "1"), context = ExpectedContext( fragment = "(x, y, z) -> x + y + z", @@ -3758,7 +3790,7 @@ class 
DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.selectExpr("transform(i, x -> x)")), - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"transform(i, lambdafunction(x, x))\"", @@ -3774,7 +3806,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.selectExpr("transform(a, x -> x)")), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`"), context = ExpectedContext( @@ -3832,7 +3864,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("map_filter(s, (x, y, z) -> x + y + z)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map("expectedNumArgs" -> "3", "actualNumArgs" -> "2"), context = ExpectedContext( fragment = "(x, y, z) -> x + y + z", @@ -3844,7 +3876,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("map_filter(s, x -> x)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map("expectedNumArgs" -> "1", "actualNumArgs" -> "2"), context = ExpectedContext( fragment = "x -> x", @@ -3856,7 +3888,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("map_filter(i, (k, v) -> k > v)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"map_filter(i, lambdafunction((k > v), k, v))\"", @@ -3873,7 +3905,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(map_filter(col("i"), (k, v) => k > v)) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", matchPVals = true, parameters = Map( "sqlExpr" -> """"map_filter\(i, lambdafunction\(`>`\(x_\d+, y_\d+\), x_\d+, y_\d+\)\)"""", @@ -3887,7 +3919,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.selectExpr("map_filter(a, (k, v) -> k > v)")), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`"), context = ExpectedContext( @@ -4029,7 +4061,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("filter(s, (x, y, z) -> x + y)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map("expectedNumArgs" -> "3", "actualNumArgs" -> "1"), context = ExpectedContext( fragment = "(x, y, z) -> x + y", @@ -4041,7 +4073,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("filter(i, x -> x)") }, - errorClass = 
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"filter(i, lambdafunction(x, x))\"", @@ -4058,7 +4090,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(filter(col("i"), x => x)) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", matchPVals = true, parameters = Map( "sqlExpr" -> """"filter\(i, lambdafunction\(x_\d+, x_\d+\)\)"""", @@ -4073,7 +4105,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("filter(s, x -> x)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"filter(s, lambdafunction(namedlambdavariable(), namedlambdavariable()))\"", "paramIndex" -> "second", @@ -4089,7 +4121,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(filter(col("s"), x => x)) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"filter(s, lambdafunction(namedlambdavariable(), namedlambdavariable()))\"", "paramIndex" -> "second", @@ -4103,7 +4135,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.selectExpr("filter(a, x -> x)")), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`"), context = ExpectedContext( @@ -4217,7 +4249,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("exists(s, (x, y) -> x + y)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map("expectedNumArgs" -> "2", "actualNumArgs" -> "1"), context = ExpectedContext( fragment = "(x, y) -> x + y", @@ -4229,7 +4261,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("exists(i, x -> x)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"exists(i, lambdafunction(x, x))\"", @@ -4246,7 +4278,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(exists(col("i"), x => x)) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", matchPVals = true, parameters = Map( "sqlExpr" -> """"exists\(i, lambdafunction\(x_\d+, x_\d+\)\)"""", @@ -4261,7 +4293,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("exists(s, x -> x)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"exists(s, lambdafunction(namedlambdavariable(), namedlambdavariable()))\"", "paramIndex" -> "second", @@ -4278,7 +4310,7 @@ class DataFrameFunctionsSuite extends QueryTest 
with SharedSparkSession { exception = intercept[AnalysisException] { df.select(exists(df("s"), x => x)) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"exists(s, lambdafunction(namedlambdavariable(), namedlambdavariable()))\"", "paramIndex" -> "second", @@ -4290,7 +4322,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.selectExpr("exists(a, x -> x)")), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`"), context = ExpectedContext( @@ -4418,7 +4450,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("forall(s, (x, y) -> x + y)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map("expectedNumArgs" -> "2", "actualNumArgs" -> "1"), context = ExpectedContext( fragment = "(x, y) -> x + y", @@ -4430,7 +4462,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("forall(i, x -> x)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"forall(i, lambdafunction(x, x))\"", @@ -4447,7 +4479,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(forall(col("i"), x => x)) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", matchPVals = true, parameters = Map( "sqlExpr" -> """"forall\(i, lambdafunction\(x_\d+, x_\d+\)\)"""", @@ -4462,7 +4494,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("forall(s, x -> x)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"forall(s, lambdafunction(namedlambdavariable(), namedlambdavariable()))\"", "paramIndex" -> "second", @@ -4478,7 +4510,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(forall(col("s"), x => x)) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"forall(s, lambdafunction(namedlambdavariable(), namedlambdavariable()))\"", "paramIndex" -> "second", @@ -4490,7 +4522,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.selectExpr("forall(a, x -> x)")), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`"), context = ExpectedContext( @@ -4500,7 +4532,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.select(forall(col("a"), x => x))), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> 
"`a`", "proposal" -> "`i`, `s`"), queryContext = Array( ExpectedContext(fragment = "col", callSitePattern = getCurrentClassCallSitePattern))) @@ -4689,7 +4721,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr(s"$agg(s, '', x -> x)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map("expectedNumArgs" -> "1", "actualNumArgs" -> "2"), context = ExpectedContext( fragment = "x -> x", @@ -4701,7 +4733,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr(s"$agg(s, '', (acc, x) -> x, (acc, x) -> x)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map("expectedNumArgs" -> "2", "actualNumArgs" -> "1"), context = ExpectedContext( fragment = "(acc, x) -> x", @@ -4715,7 +4747,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr(s"$agg(i, 0, (acc, x) -> x)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> s""""$agg(i, 0, lambdafunction(x, acc, x), lambdafunction(id, id))"""", @@ -4734,7 +4766,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(aggregate(col("i"), lit(0), (_, x) => x)) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", matchPVals = true, parameters = Map( "sqlExpr" -> """"aggregate\(i, 0, lambdafunction\(y_\d+, x_\d+, y_\d+\), lambdafunction\(x_\d+, x_\d+\)\)"""", @@ -4752,7 +4784,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr(s"$agg(s, 0, (acc, x) -> x)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> s""""$agg(s, 0, lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))"""", "paramIndex" -> "third", @@ -4772,7 +4804,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(aggregate(col("s"), lit(0), (acc, x) => x)) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> """"aggregate(s, 0, lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))"""", "paramIndex" -> "third", @@ -4788,7 +4820,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.selectExpr(s"$agg(a, 0, (acc, x) -> x)")), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`"), context = ExpectedContext( @@ -4853,7 +4885,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("map_zip_with(mii, mis, (x, 
y) -> x + y)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map("expectedNumArgs" -> "2", "actualNumArgs" -> "3"), context = ExpectedContext( fragment = "(x, y) -> x + y", @@ -4865,7 +4897,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))") }, - errorClass = "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"map_zip_with(mis, mmi, lambdafunction(concat(x, y, z), x, y, z))\"", "functionName" -> "`map_zip_with`", @@ -4881,7 +4913,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(map_zip_with(df("mis"), col("mmi"), (x, y, z) => concat(x, y, z))) }, - errorClass = "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES", matchPVals = true, parameters = Map( "sqlExpr" -> """"map_zip_with\(mis, mmi, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", @@ -4896,7 +4928,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("map_zip_with(i, mis, (x, y, z) -> concat(x, y, z))") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"map_zip_with(i, mis, lambdafunction(concat(x, y, z), x, y, z))\"", @@ -4913,7 +4945,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(map_zip_with(col("i"), col("mis"), (x, y, z) => concat(x, y, z))) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", matchPVals = true, parameters = Map( "sqlExpr" -> """"map_zip_with\(i, mis, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", @@ -4928,7 +4960,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"map_zip_with(mis, i, lambdafunction(concat(x, y, z), x, y, z))\"", @@ -4945,7 +4977,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(map_zip_with(col("mis"), col("i"), (x, y, z) => concat(x, y, z))) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", matchPVals = true, parameters = Map( "sqlExpr" -> """"map_zip_with\(mis, i, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", @@ -4960,7 +4992,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") }, - errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"map_zip_with(mmi, mmi, lambdafunction(x, x, y, z))\"", @@ -5080,7 +5112,7 @@ class 
DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { dfExample1.selectExpr("transform_keys(i, k -> k)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map("expectedNumArgs" -> "1", "actualNumArgs" -> "2"), context = ExpectedContext( fragment = "k -> k", @@ -5092,7 +5124,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { dfExample1.selectExpr("transform_keys(i, (k, v, x) -> k + 1)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map("expectedNumArgs" -> "3", "actualNumArgs" -> "2"), context = ExpectedContext( fragment = "(k, v, x) -> k + 1", @@ -5104,7 +5136,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { dfExample1.selectExpr("transform_keys(i, (k, v) -> v)").show() }, - errorClass = "NULL_MAP_KEY", + condition = "NULL_MAP_KEY", parameters = Map.empty ) @@ -5112,7 +5144,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { dfExample1.select(transform_keys(col("i"), (k, v) => v)).show() }, - errorClass = "NULL_MAP_KEY", + condition = "NULL_MAP_KEY", parameters = Map.empty ) @@ -5120,7 +5152,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"transform_keys(j, lambdafunction((k + 1), k, v))\"", @@ -5356,7 +5388,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { dfExample1.selectExpr("transform_values(i, k -> k)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map("expectedNumArgs" -> "1", "actualNumArgs" -> "2"), context = ExpectedContext( fragment = "k -> k", @@ -5368,7 +5400,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { dfExample2.selectExpr("transform_values(j, (k, v, x) -> k + 1)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map("expectedNumArgs" -> "3", "actualNumArgs" -> "2"), context = ExpectedContext( fragment = "(k, v, x) -> k + 1", @@ -5380,7 +5412,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { dfExample3.selectExpr("transform_values(x, (k, v) -> k + 1)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"transform_values(x, lambdafunction((k + 1), k, v))\"", @@ -5397,7 +5429,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { dfExample3.select(transform_values(col("x"), (k, v) => k + 1)) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", matchPVals = true, 
parameters = Map( "sqlExpr" -> @@ -5480,7 +5512,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("zip_with(a1, a2, x -> x)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", + condition = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH", parameters = Map( "expectedNumArgs" -> "1", "actualNumArgs" -> "2"), @@ -5494,7 +5526,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("zip_with(a1, a2, (x, x) -> x)") }, - errorClass = "INVALID_LAMBDA_FUNCTION_CALL.DUPLICATE_ARG_NAMES", + condition = "INVALID_LAMBDA_FUNCTION_CALL.DUPLICATE_ARG_NAMES", parameters = Map( "args" -> "`x`, `x`", "caseSensitiveConfig" -> "\"spark.sql.caseSensitive\""), @@ -5508,7 +5540,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("zip_with(a1, a2, (acc, x) -> x, (acc, x) -> x)") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> toSQLId("zip_with"), "expectedNum" -> "3", @@ -5524,7 +5556,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("zip_with(i, a2, (acc, x) -> x)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"zip_with(i, a2, lambdafunction(x, acc, x))\"", @@ -5541,7 +5573,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(zip_with(df("i"), df("a2"), (_, x) => x)) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", matchPVals = true, parameters = Map( "sqlExpr" -> @@ -5556,7 +5588,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.selectExpr("zip_with(a1, a, (acc, x) -> x)")), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map("objectName" -> "`a`", "proposal" -> "`a1`, `a2`, `i`"), context = ExpectedContext( @@ -5609,7 +5641,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(coalesce()) }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", sqlState = None, parameters = Map( "functionName" -> "`coalesce`", @@ -5622,7 +5654,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("coalesce()") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", sqlState = None, parameters = Map( "functionName" -> "`coalesce`", @@ -5635,7 +5667,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(hash()) }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", sqlState = None, parameters = Map( "functionName" -> "`hash`", @@ -5648,7 +5680,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { 
df.selectExpr("hash()") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", sqlState = None, parameters = Map( "functionName" -> "`hash`", @@ -5661,7 +5693,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(xxhash64()) }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", sqlState = None, parameters = Map( "functionName" -> "`xxhash64`", @@ -5674,7 +5706,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("xxhash64()") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", sqlState = None, parameters = Map( "functionName" -> "`xxhash64`", @@ -5687,7 +5719,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(greatest()) }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", sqlState = None, parameters = Map( "functionName" -> "`greatest`", @@ -5700,7 +5732,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("greatest()") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", sqlState = None, parameters = Map( "functionName" -> "`greatest`", @@ -5713,7 +5745,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(least()) }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", sqlState = None, parameters = Map( "functionName" -> "`least`", @@ -5726,7 +5758,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("least()") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", sqlState = None, parameters = Map( "functionName" -> "`least`", @@ -5742,7 +5774,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { df.select(map_from_arrays(concat($"k1", $"k2"), $"v")).show() }, - errorClass = "NULL_MAP_KEY", + condition = "NULL_MAP_KEY", parameters = Map.empty ) } @@ -5801,7 +5833,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(map($"m", lit(1))) }, - errorClass = "DATATYPE_MISMATCH.INVALID_MAP_KEY_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_MAP_KEY_TYPE", parameters = Map( "sqlExpr" -> "\"map(m, 1)\"", "keyType" -> "\"MAP\"" @@ -5842,7 +5874,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("select from_json('{\"a\":1}', 1)") }, - errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", + condition = "INVALID_SCHEMA.NON_STRING_LITERAL", parameters = Map( "inputSchema" -> "\"1\"" ), @@ -5931,7 +5963,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { invalidDatatypeDF.select(array_compact($"a")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"array_compact(a)\"", 
"paramIndex" -> "first", @@ -5954,7 +5986,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df3.select(array_append(col("a"), col("b"))) }, - errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", parameters = Map( "functionName" -> "`array_append`", "dataType" -> "\"ARRAY\"", @@ -5973,7 +6005,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df3.selectExpr("array_append(a, b)") }, - errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", parameters = Map( "functionName" -> "`array_append`", "leftType" -> "\"ARRAY\"", @@ -6005,7 +6037,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df6.selectExpr("array_append(a, b)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"array_append(a, b)\"", "paramIndex" -> "first", @@ -6110,7 +6142,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkException] { df1.map(r => df2.count() * r.getInt(0)).collect() }, - errorClass = "CANNOT_INVOKE_IN_TRANSFORMATIONS", + condition = "CANNOT_INVOKE_IN_TRANSFORMATIONS", parameters = Map.empty ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index ab900e2135576..e2bdf1c732078 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -284,7 +284,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { joined_df.na.fill("", cols = Seq("f2")) }, - errorClass = "AMBIGUOUS_REFERENCE", + condition = "AMBIGUOUS_REFERENCE", parameters = Map( "name" -> "`f2`", "referenceNames" -> "[`f2`, `f2`]" @@ -304,7 +304,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.na.drop("any", Seq("*")) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`*`", "proposal" -> "`name`, `age`, `height`") ) } @@ -411,7 +411,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.na.fill("hello", Seq("col2")) }, - errorClass = "AMBIGUOUS_REFERENCE", + condition = "AMBIGUOUS_REFERENCE", parameters = Map( "name" -> "`col2`", "referenceNames" -> "[`col2`, `col2`]" @@ -434,7 +434,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.na.drop("any", Seq("col2")) }, - errorClass = "AMBIGUOUS_REFERENCE", + condition = "AMBIGUOUS_REFERENCE", parameters = Map( "name" -> "`col2`", "referenceNames" -> "[`col2`, `col2`]" @@ -540,7 +540,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { } checkError( exception = exception, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`aa`", "proposal" -> "`Col`.`1`, `Col`.`2`") ) } @@ -551,7 +551,7 @@ class 
DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkUnsupportedOperationException] { df.na.replace("c1.c1-1", Map("b1" ->"a1")) }, - errorClass = "UNSUPPORTED_FEATURE.REPLACE_NESTED_COLUMN", + condition = "UNSUPPORTED_FEATURE.REPLACE_NESTED_COLUMN", parameters = Map("colName" -> "`c1`.`c1-1`") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index b3bf9405a99f2..cf4fbe61101b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -309,7 +309,7 @@ class DataFramePivotSuite extends QueryTest with SharedSparkSession { .pivot(min($"training"), Seq("Experts")) .agg(sum($"sales.earnings")) }, - errorClass = "GROUP_BY_AGGREGATE", + condition = "GROUP_BY_AGGREGATE", parameters = Map("sqlExpr" -> "min(training)"), context = ExpectedContext(fragment = "min", callSitePattern = getCurrentClassCallSitePattern) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index 310b5a62c908a..1d7698df2f1be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan} import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.functions.{count, explode, sum, year} +import org.apache.spark.sql.functions.{col, count, explode, sum, year} import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -404,7 +405,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y"))) // Test for AttachDistributedSequence - val df13 = df1.withSequenceColumn("seq") + val df13 = df1.select(distributed_sequence_id().alias("seq"), col("*")) val df14 = df13.filter($"value" === "A2") assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2"))) assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2"))) @@ -483,7 +484,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { df3.join(df1, year($"df1.timeStr") === year($"df3.tsStr")) ) checkError(ex, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`df1`.`timeStr`", "proposal" -> "`df3`.`timeStr`, `df1`.`tsStr`"), context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index cbc39557ce4cc..5ff737d2b57cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ 
-354,21 +354,21 @@ class DataFrameSetOperationsSuite extends QueryTest val df = spark.range(1).select(map(lit("key"), $"id").as("m")) checkError( exception = intercept[AnalysisException](df.intersect(df)), - errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", + condition = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", parameters = Map( "colName" -> "`m`", "dataType" -> "\"MAP\"") ) checkError( exception = intercept[AnalysisException](df.except(df)), - errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", + condition = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", parameters = Map( "colName" -> "`m`", "dataType" -> "\"MAP\"") ) checkError( exception = intercept[AnalysisException](df.distinct()), - errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", + condition = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", parameters = Map( "colName" -> "`m`", "dataType" -> "\"MAP\"")) @@ -376,7 +376,7 @@ class DataFrameSetOperationsSuite extends QueryTest df.createOrReplaceTempView("v") checkError( exception = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v")), - errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", + condition = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", parameters = Map( "colName" -> "`m`", "dataType" -> "\"MAP\""), @@ -546,7 +546,7 @@ class DataFrameSetOperationsSuite extends QueryTest exception = intercept[AnalysisException] { df1.unionByName(df2) }, - errorClass = "NUM_COLUMNS_MISMATCH", + condition = "NUM_COLUMNS_MISMATCH", parameters = Map( "operator" -> "UNION", "firstNumColumns" -> "2", @@ -610,7 +610,7 @@ class DataFrameSetOperationsSuite extends QueryTest exception = intercept[AnalysisException] { df1.unionByName(df2) }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c0.toLowerCase(Locale.ROOT)}`")) df1 = Seq((1, 1)).toDF("c0", "c1") df2 = Seq((1, 1)).toDF(c0, c1) @@ -618,7 +618,7 @@ class DataFrameSetOperationsSuite extends QueryTest exception = intercept[AnalysisException] { df1.unionByName(df2) }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c0.toLowerCase(Locale.ROOT)}`")) } } @@ -1022,7 +1022,7 @@ class DataFrameSetOperationsSuite extends QueryTest exception = intercept[AnalysisException] { df1.unionByName(df2) }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`c`", "fields" -> "`a`, `b`")) // If right side of the nested struct has extra col. 
@@ -1032,7 +1032,7 @@ class DataFrameSetOperationsSuite extends QueryTest exception = intercept[AnalysisException] { df1.unionByName(df2) }, - errorClass = "INCOMPATIBLE_COLUMN_TYPE", + condition = "INCOMPATIBLE_COLUMN_TYPE", parameters = Map( "tableOrdinalNumber" -> "second", "columnOrdinalNumber" -> "third", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 8eee8fc37661c..2f7b072fb7ece 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -143,7 +143,7 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { dfx.stat.freqItems(Array("num")) }, - errorClass = "AMBIGUOUS_REFERENCE", + condition = "AMBIGUOUS_REFERENCE", parameters = Map( "name" -> "`num`", "referenceNames" -> "[`table1`.`num`, `table2`.`num`]" @@ -155,7 +155,7 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { dfx.stat.approxQuantile("num", Array(0.1), 0.0) }, - errorClass = "AMBIGUOUS_REFERENCE", + condition = "AMBIGUOUS_REFERENCE", parameters = Map( "name" -> "`num`", "referenceNames" -> "[`table1`.`num`, `table2`.`num`]" @@ -167,7 +167,7 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { dfx.stat.cov("num", "num") }, - errorClass = "AMBIGUOUS_REFERENCE", + condition = "AMBIGUOUS_REFERENCE", parameters = Map( "name" -> "`num`", "referenceNames" -> "[`table1`.`num`, `table2`.`num`]" @@ -177,7 +177,7 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { dfx.stat.corr("num", "num") }, - errorClass = "AMBIGUOUS_REFERENCE", + condition = "AMBIGUOUS_REFERENCE", parameters = Map( "name" -> "`num`", "referenceNames" -> "[`table1`.`num`, `table2`.`num`]" @@ -588,7 +588,7 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkIllegalArgumentException] { person2.summary("foo") }, - errorClass = "_LEGACY_ERROR_TEMP_2114", + condition = "_LEGACY_ERROR_TEMP_2114", parameters = Map("stats" -> "foo") ) @@ -596,7 +596,7 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkIllegalArgumentException] { person2.summary("foo%") }, - errorClass = "_LEGACY_ERROR_TEMP_2113", + condition = "_LEGACY_ERROR_TEMP_2113", parameters = Map("stats" -> "foo%") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b1c41033fd760..e1774cab4a0de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -29,6 +29,7 @@ import org.scalatest.matchers.should.Matchers._ import org.apache.spark.SparkException import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -224,7 +225,7 @@ class DataFrameSuite extends QueryTest exception = intercept[AnalysisException] { df.select(explode($"*")) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + 
condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"explode(csv)\"", "paramIndex" -> "first", @@ -568,7 +569,7 @@ class DataFrameSuite extends QueryTest testData.toDF().withColumns(Seq("newCol1", "newCOL1"), Seq(col("key") + 1, col("key") + 2)) }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`newcol1`")) } @@ -588,7 +589,7 @@ class DataFrameSuite extends QueryTest testData.toDF().withColumns(Seq("newCol1", "newCol1"), Seq(col("key") + 1, col("key") + 2)) }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`newCol1`")) } } @@ -631,7 +632,7 @@ class DataFrameSuite extends QueryTest exception = intercept[AnalysisException] { df1.withMetadata("x1", metadata) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`x1`", "proposal" -> "`x`") ) } @@ -1116,7 +1117,7 @@ class DataFrameSuite extends QueryTest exception = intercept[org.apache.spark.sql.AnalysisException] { df(name) }, - errorClass = "_LEGACY_ERROR_TEMP_1049", + condition = "_LEGACY_ERROR_TEMP_1049", parameters = Map("name" -> name)) } @@ -1202,7 +1203,7 @@ class DataFrameSuite extends QueryTest } checkError( exception = e, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`column1`")) // multiple duplicate columns present @@ -1213,7 +1214,7 @@ class DataFrameSuite extends QueryTest } checkError( exception = f, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`column1`")) } @@ -1245,7 +1246,7 @@ class DataFrameSuite extends QueryTest exception = intercept[AnalysisException] { insertion.write.insertInto("rdd_base") }, - errorClass = "UNSUPPORTED_INSERT.RDD_BASED", + condition = "UNSUPPORTED_INSERT.RDD_BASED", parameters = Map.empty ) @@ -1256,7 +1257,7 @@ class DataFrameSuite extends QueryTest exception = intercept[AnalysisException] { insertion.write.insertInto("indirect_ds") }, - errorClass = "UNSUPPORTED_INSERT.RDD_BASED", + condition = "UNSUPPORTED_INSERT.RDD_BASED", parameters = Map.empty ) @@ -1266,7 +1267,7 @@ class DataFrameSuite extends QueryTest exception = intercept[AnalysisException] { insertion.write.insertInto("one_row") }, - errorClass = "UNSUPPORTED_INSERT.RDD_BASED", + condition = "UNSUPPORTED_INSERT.RDD_BASED", parameters = Map.empty ) } @@ -2036,7 +2037,7 @@ class DataFrameSuite extends QueryTest exception = intercept[AnalysisException] { df.groupBy($"d", $"b").as[GroupByKey, Row] }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`d`", "proposal" -> "`a`, `b`, `c`"), context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) } @@ -2316,7 +2317,8 @@ class DataFrameSuite extends QueryTest } test("SPARK-36338: DataFrame.withSequenceColumn should append unique sequence IDs") { - val ids = spark.range(10).repartition(5).withSequenceColumn("default_index") + val ids = spark.range(10).repartition(5).select( + distributed_sequence_id().alias("default_index"), col("id")) assert(ids.collect().map(_.getLong(0)).toSet === Range(0, 10).toSet) assert(ids.take(5).map(_.getLong(0)).toSet === Range(0, 5).toSet) } @@ -2548,7 +2550,7 @@ class DataFrameSuite extends QueryTest exception = intercept[ParseException] { 
spark.range(1).toDF("CASE").filter("CASE").collect() }, - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'CASE'", "hint" -> "")) } } @@ -2560,7 +2562,7 @@ class DataFrameSuite extends QueryTest exception = intercept[AnalysisException] { spark.range(1).createTempView("AUTHORIZATION") }, - errorClass = "_LEGACY_ERROR_TEMP_1321", + condition = "_LEGACY_ERROR_TEMP_1321", parameters = Map("viewName" -> "AUTHORIZATION")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameToSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameToSchemaSuite.scala index 160f583c983d8..f166043e4d554 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameToSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameToSchemaSuite.scala @@ -58,7 +58,7 @@ class DataFrameToSchemaSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkThrowable](Seq("a" -> "b").toDF("i", "j").to(schema)) checkError( exception = e, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`non_exist`", "proposal" -> "`i`, `j`")) @@ -69,7 +69,7 @@ class DataFrameToSchemaSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkThrowable](Seq("a" -> "b").toDF("i", "I").to(schema)) checkError( exception = e, - errorClass = "AMBIGUOUS_COLUMN_OR_FIELD", + condition = "AMBIGUOUS_COLUMN_OR_FIELD", parameters = Map( "name" -> "`i`", "n" -> "2")) @@ -92,7 +92,7 @@ class DataFrameToSchemaSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkThrowable](data.to(schema)) checkError( exception = e, - errorClass = "NULLABLE_COLUMN_OR_FIELD", + condition = "NULLABLE_COLUMN_OR_FIELD", parameters = Map("name" -> "`i`")) } @@ -108,7 +108,7 @@ class DataFrameToSchemaSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkThrowable](Seq("a" -> 1).toDF("i", "j").to(schema)) checkError( exception = e, - errorClass = "INVALID_COLUMN_OR_FIELD_DATA_TYPE", + condition = "INVALID_COLUMN_OR_FIELD_DATA_TYPE", parameters = Map( "name" -> "`i`", "type" -> "\"STRING\"", @@ -160,7 +160,7 @@ class DataFrameToSchemaSuite extends QueryTest with SharedSparkSession { } checkError( exception = e, - errorClass = "UNRESOLVED_FIELD.WITH_SUGGESTION", + condition = "UNRESOLVED_FIELD.WITH_SUGGESTION", parameters = Map( "fieldName" -> "`non_exist`", "columnPath" -> "`struct`", @@ -200,7 +200,7 @@ class DataFrameToSchemaSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkThrowable](data.to(schema)) checkError( exception = e, - errorClass = "NULLABLE_COLUMN_OR_FIELD", + condition = "NULLABLE_COLUMN_OR_FIELD", parameters = Map("name" -> "`struct`.`i`")) } @@ -220,7 +220,7 @@ class DataFrameToSchemaSuite extends QueryTest with SharedSparkSession { } checkError( exception = e, - errorClass = "INVALID_COLUMN_OR_FIELD_DATA_TYPE", + condition = "INVALID_COLUMN_OR_FIELD_DATA_TYPE", parameters = Map( "name" -> "`struct`.`i`", "type" -> "\"STRING\"", @@ -284,7 +284,7 @@ class DataFrameToSchemaSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkThrowable](data.to(schema)) checkError( exception = e, - errorClass = "NOT_NULL_CONSTRAINT_VIOLATION.ARRAY_ELEMENT", + condition = "NOT_NULL_CONSTRAINT_VIOLATION.ARRAY_ELEMENT", parameters = Map("columnPath" -> "`arr`")) } @@ -362,7 +362,7 @@ class DataFrameToSchemaSuite extends QueryTest with SharedSparkSession { val e = 
intercept[SparkThrowable](data.to(schema)) checkError( exception = e, - errorClass = "NOT_NULL_CONSTRAINT_VIOLATION.MAP_VALUE", + condition = "NOT_NULL_CONSTRAINT_VIOLATION.MAP_VALUE", parameters = Map("columnPath" -> "`map`")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTransposeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTransposeSuite.scala new file mode 100644 index 0000000000000..e6e8b6d5e5b01 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTransposeSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +class DataFrameTransposeSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + // + // Test cases: input parameter + // + + test("transpose with default index column") { + checkAnswer( + salary.transpose(), + Row("salary", 2000.0, 1000.0) :: Nil + ) + } + + test("transpose with specified index column") { + checkAnswer( + salary.transpose($"salary"), + Row("personId", 1, 0) :: Nil + ) + } + + // + // Test cases: API behavior + // + + test("enforce least common type for non-index columns") { + val df = Seq(("x", 1, 10.0), ("y", 2, 20.0)).toDF("name", "id", "value") + val transposedDf = df.transpose() + checkAnswer( + transposedDf, + Row("id", 1.0, 2.0) :: Row("value", 10.0, 20.0) :: Nil + ) + // (id,IntegerType) -> (x,DoubleType) + // (value,DoubleType) -> (y,DoubleType) + assertResult(DoubleType)(transposedDf.schema("x").dataType) + assertResult(DoubleType)(transposedDf.schema("y").dataType) + + val exception = intercept[AnalysisException] { + person.transpose() + } + assert(exception.getMessage.contains( + "[TRANSPOSE_NO_LEAST_COMMON_TYPE] Transpose requires non-index columns " + + "to share a least common type")) + } + + test("enforce ascending order based on index column values for transposed columns") { + val transposedDf = person.transpose($"name") + checkAnswer( + transposedDf, + Row("id", 1, 0) :: Row("age", 20, 30) :: Nil + ) + // mike, jim -> jim, mike + assertResult(Array("key", "jim", "mike"))(transposedDf.columns) + } + + test("enforce AtomicType Attribute for index column values") { + val exceptionAtomic = intercept[AnalysisException] { + complexData.transpose($"m") // (m,MapType(StringType,IntegerType,false)) + } + assert(exceptionAtomic.getMessage.contains( + "[TRANSPOSE_INVALID_INDEX_COLUMN] Invalid index column for TRANSPOSE because:" + + " Index column must be of atomic type, but found")) + + val exceptionAttribute = intercept[AnalysisException] { + // 
(s,StructType(StructField(key,IntegerType,false),StructField(value,StringType,true))) + complexData.transpose($"s.key") + } + assert(exceptionAttribute.getMessage.contains( + "[TRANSPOSE_INVALID_INDEX_COLUMN] Invalid index column for TRANSPOSE because:" + + " Index column must be an atomic attribute")) + } + + test("enforce transpose max values") { + spark.conf.set(SQLConf.DATAFRAME_TRANSPOSE_MAX_VALUES.key, 1) + val exception = intercept[AnalysisException]( + person.transpose($"name") + ) + assert(exception.getMessage.contains( + "[TRANSPOSE_EXCEED_ROW_LIMIT] Number of rows exceeds the allowed limit of")) + spark.conf.set(SQLConf.DATAFRAME_TRANSPOSE_MAX_VALUES.key, + SQLConf.DATAFRAME_TRANSPOSE_MAX_VALUES.defaultValue.get) + } + + // + // Test cases: special frame + // + + test("transpose empty frame w. column names") { + val schema = StructType(Seq( + StructField("id", IntegerType), + StructField("name", StringType) + )) + val emptyDF = spark.createDataFrame(spark.sparkContext.emptyRDD[Row], schema) + val transposedDF = emptyDF.transpose() + checkAnswer( + transposedDF, + Row("name") :: Nil + ) + assertResult(StringType)(transposedDF.schema("key").dataType) + } + + test("transpose empty frame w/o column names") { + val emptyDF = spark.emptyDataFrame + checkAnswer( + emptyDF.transpose(), + Nil + ) + } + + test("transpose frame with only index column") { + val transposedDf = tableName.transpose() + checkAnswer( + transposedDf, + Nil + ) + assertResult(Array("key", "test"))(transposedDf.columns) + } + + test("transpose frame with duplicates in index column") { + val df = Seq( + ("A", 1, 2), + ("B", 3, 4), + ("A", 5, 6) + ).toDF("id", "val1", "val2") + val transposedDf = df.transpose() + checkAnswer( + transposedDf, + Seq( + Row("val1", 1, 5, 3), + Row("val2", 2, 6, 4) + ) + ) + assertResult(Array("key", "A", "A", "B"))(transposedDf.columns) + } + + test("transpose frame with nulls in index column") { + val df = Seq( + ("A", 1, 2), + ("B", 3, 4), + (null, 5, 6) + ).toDF("id", "val1", "val2") + val transposedDf = df.transpose() + checkAnswer( + transposedDf, + Seq( + Row("val1", 1, 3), + Row("val2", 2, 4) + ) + ) + assertResult(Array("key", "A", "B"))(transposedDf.columns) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index c03c5e878427f..d03288d7dbcdf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -186,7 +186,7 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { $"key", count("key").over( Window.partitionBy($"value").orderBy($"key").rowsBetween(2147483648L, 0)))), - errorClass = "INVALID_BOUNDARY.START", + condition = "INVALID_BOUNDARY.START", parameters = Map( "invalidValue" -> "2147483648L", "boundary" -> "`start`", @@ -200,7 +200,7 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { $"key", count("key").over( Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))), - errorClass = "INVALID_BOUNDARY.END", + condition = "INVALID_BOUNDARY.END", parameters = Map( "invalidValue" -> "2147483648L", "boundary" -> "`end`", @@ -226,7 +226,7 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { df.select( min("key").over(window.rangeBetween(Window.unboundedPreceding, 1))) ), - errorClass = "DATATYPE_MISMATCH.RANGE_FRAME_MULTI_ORDER", + 
condition = "DATATYPE_MISMATCH.RANGE_FRAME_MULTI_ORDER", parameters = Map( "orderSpec" -> """key#\d+ ASC NULLS FIRST,value#\d+ ASC NULLS FIRST""", "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + @@ -242,7 +242,7 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { df.select( min("key").over(window.rangeBetween(-1, Window.unboundedFollowing))) ), - errorClass = "DATATYPE_MISMATCH.RANGE_FRAME_MULTI_ORDER", + condition = "DATATYPE_MISMATCH.RANGE_FRAME_MULTI_ORDER", parameters = Map( "orderSpec" -> """key#\d+ ASC NULLS FIRST,value#\d+ ASC NULLS FIRST""", "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + @@ -258,7 +258,7 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { df.select( min("key").over(window.rangeBetween(-1, 1))) ), - errorClass = "DATATYPE_MISMATCH.RANGE_FRAME_MULTI_ORDER", + condition = "DATATYPE_MISMATCH.RANGE_FRAME_MULTI_ORDER", parameters = Map( "orderSpec" -> """key#\d+ ASC NULLS FIRST,value#\d+ ASC NULLS FIRST""", "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + @@ -287,7 +287,7 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { df.select( min("value").over(window.rangeBetween(Window.unboundedPreceding, 1))) ), - errorClass = "DATATYPE_MISMATCH.SPECIFIED_WINDOW_FRAME_UNACCEPTED_TYPE", + condition = "DATATYPE_MISMATCH.SPECIFIED_WINDOW_FRAME_UNACCEPTED_TYPE", parameters = Map( "location" -> "upper", "exprType" -> "\"STRING\"", @@ -303,7 +303,7 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { df.select( min("value").over(window.rangeBetween(-1, Window.unboundedFollowing))) ), - errorClass = "DATATYPE_MISMATCH.SPECIFIED_WINDOW_FRAME_UNACCEPTED_TYPE", + condition = "DATATYPE_MISMATCH.SPECIFIED_WINDOW_FRAME_UNACCEPTED_TYPE", parameters = Map( "location" -> "lower", "exprType" -> "\"STRING\"", @@ -319,7 +319,7 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { df.select( min("value").over(window.rangeBetween(-1, 1))) ), - errorClass = "DATATYPE_MISMATCH.SPECIFIED_WINDOW_FRAME_UNACCEPTED_TYPE", + condition = "DATATYPE_MISMATCH.SPECIFIED_WINDOW_FRAME_UNACCEPTED_TYPE", parameters = Map( "location" -> "lower", "exprType" -> "\"STRING\"", @@ -512,7 +512,7 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select($"key", count("key").over(windowSpec)).collect() }, - errorClass = "DATATYPE_MISMATCH.SPECIFIED_WINDOW_FRAME_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.SPECIFIED_WINDOW_FRAME_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"RANGE BETWEEN NULL FOLLOWING AND 2 FOLLOWING\"", "lower" -> "\"NULL\"", @@ -534,7 +534,7 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select($"key", count("key").over(windowSpec)).collect() }, - errorClass = "DATATYPE_MISMATCH.SPECIFIED_WINDOW_FRAME_WITHOUT_FOLDABLE", + condition = "DATATYPE_MISMATCH.SPECIFIED_WINDOW_FRAME_WITHOUT_FOLDABLE", parameters = Map( "sqlExpr" -> "\"RANGE BETWEEN nonfoldableliteral() FOLLOWING AND 2 FOLLOWING\"", "location" -> "lower", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index ace4d5b294a78..8a86aa10887c0 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -388,7 +388,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest df.select($"key", count("invalid").over())) checkError( exception = e, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`invalid`", "proposal" -> "`value`, `key`"), @@ -870,7 +870,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest lag($"value", 3, null, true).over(window), lag(concat($"value", $"key"), 1, null, true).over(window)).orderBy($"order").collect() }, - errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( "sqlExpr" -> "\"lag(value, nonfoldableliteral(), NULL)\"", "inputName" -> "`offset`", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala index 2275d8c213978..b7ac6af22a204 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala @@ -146,7 +146,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo exception = intercept[AnalysisException] { spark.table("source").withColumnRenamed("data", "d").writeTo("testcat.table_name").append() }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", parameters = Map("tableName" -> "`testcat`.`table_name`", "colName" -> "`data`") ) @@ -251,7 +251,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source").withColumnRenamed("data", "d") .writeTo("testcat.table_name").overwrite(lit(true)) }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", parameters = Map("tableName" -> "`testcat`.`table_name`", "colName" -> "`data`") ) @@ -356,7 +356,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source").withColumnRenamed("data", "d") .writeTo("testcat.table_name").overwritePartitions() }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", parameters = Map("tableName" -> "`testcat`.`table_name`", "colName" -> "`data`") ) @@ -829,14 +829,14 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo exception = intercept[AnalysisException] { ds.write }, - errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", + condition = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", parameters = Map("methodName" -> "`write`")) checkError( exception = intercept[AnalysisException] { ds.writeTo("testcat.table_name") }, - errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", + condition = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", parameters = Map("methodName" -> "`writeTo`")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index fdb2ec30fdd2d..089ce79201dd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -345,7 +345,7 @@ class DatasetSuite extends QueryTest exception = intercept[AnalysisException] 
{ ds.select(expr("`(_1)?+.+`").as[Int]) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map( "objectName" -> "`(_1)?+.+`", @@ -359,7 +359,7 @@ class DatasetSuite extends QueryTest exception = intercept[AnalysisException] { ds.select(expr("`(_1|_2)`").as[Int]) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map( "objectName" -> "`(_1|_2)`", @@ -373,7 +373,7 @@ class DatasetSuite extends QueryTest exception = intercept[AnalysisException] { ds.select(ds("`(_1)?+.+`")) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`(_1)?+.+`", "proposal" -> "`_1`, `_2`") ) @@ -381,7 +381,7 @@ class DatasetSuite extends QueryTest exception = intercept[AnalysisException] { ds.select(ds("`(_1|_2)`")) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`(_1|_2)`", "proposal" -> "`_1`, `_2`") ) } @@ -549,7 +549,7 @@ class DatasetSuite extends QueryTest exception = intercept[AnalysisException]( ds1.joinWith(ds2, $"a.value" === $"b.value", joinType) ), - errorClass = "INVALID_JOIN_TYPE_FOR_JOINWITH", + condition = "INVALID_JOIN_TYPE_FOR_JOINWITH", sqlState = "42613", parameters = semiErrorParameters ) @@ -611,7 +611,7 @@ class DatasetSuite extends QueryTest (g, iter) => Iterator(g, iter.mkString(", ")) } }, - errorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", + condition = "INVALID_USAGE_OF_STAR_OR_REGEX", parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"), context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } @@ -640,7 +640,7 @@ class DatasetSuite extends QueryTest (g, iter) => Iterator(g, iter.mkString(", ")) } }, - errorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", + condition = "INVALID_USAGE_OF_STAR_OR_REGEX", parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"), context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } @@ -1187,11 +1187,15 @@ class DatasetSuite extends QueryTest exception = intercept[AnalysisException] { df.as[KryoData] }, - errorClass = "DATATYPE_MISMATCH.CAST_WITHOUT_SUGGESTION", + condition = "CANNOT_UP_CAST_DATATYPE", parameters = Map( - "sqlExpr" -> "\"a\"", - "srcType" -> "\"DOUBLE\"", - "targetType" -> "\"BINARY\"")) + "expression" -> "a", + "sourceType" -> "\"DOUBLE\"", + "targetType" -> "\"BINARY\"", + "details" -> ("The type path of the target object is:\n- root class: " + + "\"org.apache.spark.sql.KryoData\"\n" + + "You can either add an explicit cast to the input data or choose a " + + "higher precision type of the field in the target object"))) } test("Java encoder") { @@ -1239,7 +1243,7 @@ class DatasetSuite extends QueryTest val ds = Seq(ClassData("a", 1)).toDS() checkError( exception = intercept[AnalysisException] (ds.as[ClassData2]), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`c`", "proposal" -> "`a`, `b`")) @@ -1429,7 +1433,7 @@ class DatasetSuite extends QueryTest dataset.createTempView("tempView")) intercept[AnalysisException](dataset.createTempView("tempView")) checkError(e, - errorClass = "TEMP_TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TEMP_TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> "`tempView`")) 
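// A second minimal sketch under the same assumptions (QueryTest with SharedSparkSession
// and testImplicits in scope), mirroring the temp-view hunk just above: registering the
// same temporary view name twice and asserting on the renamed `condition` keyword, with
// the error class name and parameters unchanged from the expectation shown there.
import org.apache.spark.sql.AnalysisException

val ds = Seq(1).toDS()
ds.createTempView("tempView")
checkError(
  exception = intercept[AnalysisException](ds.createTempView("tempView")),
  condition = "TEMP_TABLE_OR_VIEW_ALREADY_EXISTS",
  parameters = Map("relationName" -> "`tempView`"))
spark.catalog.dropTempView("tempView")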
dataset.sparkSession.catalog.dropTempView("tempView") @@ -1440,7 +1444,7 @@ class DatasetSuite extends QueryTest val e = intercept[AnalysisException]( dataset.createTempView("test_db.tempView")) checkError(e, - errorClass = "TEMP_VIEW_NAME_TOO_MANY_NAME_PARTS", + condition = "TEMP_VIEW_NAME_TOO_MANY_NAME_PARTS", parameters = Map("actualName" -> "test_db.tempView")) } @@ -1902,19 +1906,19 @@ class DatasetSuite extends QueryTest exception = intercept[SparkUnsupportedOperationException] { Seq(CircularReferenceClassA(null)).toDS() }, - errorClass = "_LEGACY_ERROR_TEMP_2139", + condition = "_LEGACY_ERROR_TEMP_2139", parameters = Map("t" -> "org.apache.spark.sql.CircularReferenceClassA")) checkError( exception = intercept[SparkUnsupportedOperationException] { Seq(CircularReferenceClassC(null)).toDS() }, - errorClass = "_LEGACY_ERROR_TEMP_2139", + condition = "_LEGACY_ERROR_TEMP_2139", parameters = Map("t" -> "org.apache.spark.sql.CircularReferenceClassC")) checkError( exception = intercept[SparkUnsupportedOperationException] { Seq(CircularReferenceClassD(null)).toDS() }, - errorClass = "_LEGACY_ERROR_TEMP_2139", + condition = "_LEGACY_ERROR_TEMP_2139", parameters = Map("t" -> "org.apache.spark.sql.CircularReferenceClassD")) } @@ -2051,17 +2055,17 @@ class DatasetSuite extends QueryTest test("SPARK-22472: add null check for top-level primitive values") { // If the primitive values are from Option, we need to do runtime null check. val ds = Seq(Some(1), None).toDS().as[Int] - val errorClass = "NOT_NULL_ASSERT_VIOLATION" + val condition = "NOT_NULL_ASSERT_VIOLATION" val sqlState = "42000" val parameters = Map("walkedTypePath" -> "\n- root class: \"int\"\n") checkError( exception = intercept[SparkRuntimeException](ds.collect()), - errorClass = errorClass, + condition = condition, sqlState = sqlState, parameters = parameters) checkError( exception = intercept[SparkRuntimeException](ds.map(_ * 2).collect()), - errorClass = errorClass, + condition = condition, sqlState = sqlState, parameters = parameters) @@ -2071,12 +2075,12 @@ class DatasetSuite extends QueryTest val ds = spark.read.parquet(path.getCanonicalPath).as[Int] checkError( exception = intercept[SparkRuntimeException](ds.collect()), - errorClass = errorClass, + condition = condition, sqlState = sqlState, parameters = parameters) checkError( exception = intercept[SparkRuntimeException](ds.map(_ * 2).collect()), - errorClass = errorClass, + condition = condition, sqlState = sqlState, parameters = parameters) } @@ -2317,7 +2321,7 @@ class DatasetSuite extends QueryTest exception = intercept[AnalysisException] { ds(colName) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> colName, "proposal" -> "`field`.`1`, `field 2`") ) } @@ -2334,7 +2338,7 @@ class DatasetSuite extends QueryTest // has different semantics than ds.select(colName) ds.select(colName) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map( "objectName" -> s"`${colName.replace(".", "`.`")}`", @@ -2349,7 +2353,7 @@ class DatasetSuite extends QueryTest exception = intercept[AnalysisException] { Seq(0).toDF("the.id").select("the.id") }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map( "objectName" -> "`the`.`id`", @@ -2364,7 +2368,7 @@ class DatasetSuite extends QueryTest .select(map(lit("key"), lit(1)).as("map"), 
lit(2).as("other.column")) .select($"`map`"($"nonexisting")).show() }, - errorClass = "UNRESOLVED_MAP_KEY.WITH_SUGGESTION", + condition = "UNRESOLVED_MAP_KEY.WITH_SUGGESTION", sqlState = None, parameters = Map( "objectName" -> "`nonexisting`", @@ -2676,7 +2680,7 @@ class DatasetSuite extends QueryTest // Expression decoding error checkError( exception = exception, - errorClass = "EXPRESSION_DECODING_FAILED", + condition = "EXPRESSION_DECODING_FAILED", parameters = Map( "expressions" -> expressions.map( _.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")) @@ -2684,7 +2688,7 @@ class DatasetSuite extends QueryTest // class unsupported by map objects checkError( exception = exception.getCause.asInstanceOf[org.apache.spark.SparkRuntimeException], - errorClass = "CLASS_UNSUPPORTED_BY_MAP_OBJECTS", + condition = "CLASS_UNSUPPORTED_BY_MAP_OBJECTS", parameters = Map("cls" -> classOf[Array[Int]].getName)) } } @@ -2697,7 +2701,7 @@ class DatasetSuite extends QueryTest } checkError( exception = exception, - errorClass = "EXPRESSION_ENCODING_FAILED", + condition = "EXPRESSION_ENCODING_FAILED", parameters = Map( "expressions" -> enc.serializer.map( _.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")) @@ -2746,7 +2750,7 @@ class DatasetSuite extends QueryTest } checkError( exception, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map("objectName" -> "`a`", "proposal" -> "`value`"), context = ExpectedContext(fragment = "col", callSitePattern = callSitePattern)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala index 5e5e4d09c5274..dad69a9aab06d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala @@ -149,7 +149,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e, - errorClass = "UNPIVOT_REQUIRES_VALUE_COLUMNS", + condition = "UNPIVOT_REQUIRES_VALUE_COLUMNS", parameters = Map()) // ids expressions are not allowed when no values are given @@ -162,7 +162,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e2, - errorClass = "UNPIVOT_REQUIRES_ATTRIBUTES", + condition = "UNPIVOT_REQUIRES_ATTRIBUTES", parameters = Map( "given" -> "id", "empty" -> "value", @@ -178,7 +178,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e3, - errorClass = "UNPIVOT_REQUIRES_ATTRIBUTES", + condition = "UNPIVOT_REQUIRES_ATTRIBUTES", parameters = Map( "given" -> "id", "empty" -> "value", @@ -207,7 +207,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e, - errorClass = "UNPIVOT_REQUIRES_VALUE_COLUMNS", + condition = "UNPIVOT_REQUIRES_VALUE_COLUMNS", parameters = Map()) } @@ -315,7 +315,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e, - errorClass = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH", + condition = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH", parameters = Map( "types" -> ( """"BIGINT" (`long1`, `long2`), """ + @@ -371,7 +371,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e1, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`1`", "proposal" -> "`id`, `int1`, `str1`, `long1`, `str2`"), @@ -388,7 +388,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e2, - 
errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`does`", "proposal" -> "`id`, `int1`, `long1`, `str1`, `str2`"), @@ -404,7 +404,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e3, - errorClass = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH", + condition = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH", parameters = Map( "types" -> """"BIGINT" (`long1`), "INT" (`id`, `int1`), "STRING" (`str1`, `str2`)""" ) @@ -420,7 +420,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e4, - errorClass = "UNPIVOT_REQUIRES_VALUE_COLUMNS", + condition = "UNPIVOT_REQUIRES_VALUE_COLUMNS", parameters = Map() ) @@ -436,7 +436,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e5, - errorClass = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH", + condition = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH", parameters = Map( "types" -> """"BIGINT" (`long1`), "INT" (`id`, `int1`), "STRING" (`str1`, `str2`)""" ) @@ -452,7 +452,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e6, - errorClass = "UNPIVOT_REQUIRES_VALUE_COLUMNS", + condition = "UNPIVOT_REQUIRES_VALUE_COLUMNS", parameters = Map.empty ) } @@ -507,7 +507,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`an`.`id`", "proposal" -> "`an.id`, `int1`, `long1`, `str.one`, `str.two`"), @@ -607,7 +607,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e, - errorClass = "UNPIVOT_REQUIRES_ATTRIBUTES", + condition = "UNPIVOT_REQUIRES_ATTRIBUTES", parameters = Map( "given" -> "value", "empty" -> "id", @@ -635,7 +635,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e2, - errorClass = "UNPIVOT_REQUIRES_ATTRIBUTES", + condition = "UNPIVOT_REQUIRES_ATTRIBUTES", parameters = Map( "given" -> "value", "empty" -> "id", @@ -661,7 +661,7 @@ class DatasetUnpivotSuite extends QueryTest } checkError( exception = e, - errorClass = "UNPIVOT_VALUE_SIZE_MISMATCH", + condition = "UNPIVOT_VALUE_SIZE_MISMATCH", parameters = Map("names" -> "2")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index b261ecfb0cee4..4cab05dfd2b9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -54,7 +54,7 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("SELECT CURDATE(1)") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`curdate`", "expectedNum" -> "0", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 229677d208136..e44bd5de4f4c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -52,7 +52,7 @@ class FileBasedDataSourceSuite extends QueryTest override def beforeAll(): Unit = { super.beforeAll() - spark.conf.set(SQLConf.ORC_IMPLEMENTATION, "native") + spark.conf.set(SQLConf.ORC_IMPLEMENTATION.key, "native") } override def afterAll(): 
Unit = { @@ -133,7 +133,7 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[AnalysisException] { spark.emptyDataFrame.write.format(format).save(outputPath.toString) }, - errorClass = "_LEGACY_ERROR_TEMP_1142", + condition = "_LEGACY_ERROR_TEMP_1142", parameters = Map.empty ) } @@ -150,7 +150,7 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[AnalysisException] { df.write.format(format).save(outputPath.toString) }, - errorClass = "_LEGACY_ERROR_TEMP_1142", + condition = "_LEGACY_ERROR_TEMP_1142", parameters = Map.empty ) } @@ -250,7 +250,7 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[SparkException] { testIgnoreMissingFiles(options) }, - errorClass = "FAILED_READ_FILE.FILE_NOT_EXIST", + condition = "FAILED_READ_FILE.FILE_NOT_EXIST", parameters = Map("path" -> ".*") ) } @@ -282,7 +282,7 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[AnalysisException] { Seq(1).toDF().write.text(textDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`value`", "columnType" -> "\"INT\"", @@ -293,7 +293,7 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[AnalysisException] { Seq(1.2).toDF().write.text(textDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`value`", "columnType" -> "\"DOUBLE\"", @@ -304,7 +304,7 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[AnalysisException] { Seq(true).toDF().write.text(textDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`value`", "columnType" -> "\"BOOLEAN\"", @@ -315,7 +315,7 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[AnalysisException] { Seq(1).toDF("a").selectExpr("struct(a)").write.text(textDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`struct(a)`", "columnType" -> "\"STRUCT\"", @@ -326,7 +326,7 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[AnalysisException] { Seq((Map("Tesla" -> 3))).toDF("cars").write.mode("overwrite").text(textDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`cars`", "columnType" -> "\"MAP\"", @@ -338,7 +338,7 @@ class FileBasedDataSourceSuite extends QueryTest Seq((Array("Tesla", "Chevy", "Ford"))).toDF("brands") .write.mode("overwrite").text(textDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`brands`", "columnType" -> "\"ARRAY\"", @@ -352,7 +352,7 @@ class FileBasedDataSourceSuite extends QueryTest val schema = StructType(StructField("a", IntegerType, true) :: Nil) spark.read.schema(schema).text(textDir).collect() }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`a`", "columnType" -> "\"INT\"", @@ -364,7 +364,7 @@ class FileBasedDataSourceSuite extends QueryTest val schema = StructType(StructField("a", DoubleType, true) :: Nil) spark.read.schema(schema).text(textDir).collect() }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", 
+ condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`a`", "columnType" -> "\"DOUBLE\"", @@ -376,7 +376,7 @@ class FileBasedDataSourceSuite extends QueryTest val schema = StructType(StructField("a", BooleanType, true) :: Nil) spark.read.schema(schema).text(textDir).collect() }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`a`", "columnType" -> "\"BOOLEAN\"", @@ -397,7 +397,7 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[AnalysisException] { Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`struct(a, b)`", "columnType" -> "\"STRUCT\"", @@ -410,7 +410,7 @@ class FileBasedDataSourceSuite extends QueryTest spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`a`", "columnType" -> "\"STRUCT\"", @@ -421,7 +421,7 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[AnalysisException] { Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.mode("overwrite").csv(csvDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`cars`", "columnType" -> "\"MAP\"", @@ -434,7 +434,7 @@ class FileBasedDataSourceSuite extends QueryTest spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`a`", "columnType" -> "\"MAP\"", @@ -446,7 +446,7 @@ class FileBasedDataSourceSuite extends QueryTest Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands") .write.mode("overwrite").csv(csvDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`brands`", "columnType" -> "\"ARRAY\"", @@ -459,7 +459,7 @@ class FileBasedDataSourceSuite extends QueryTest spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`a`", "columnType" -> "\"ARRAY\"", @@ -471,7 +471,7 @@ class FileBasedDataSourceSuite extends QueryTest Seq((1, new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") .write.mode("overwrite").csv(csvDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`vectors`", "columnType" -> "UDT(\"ARRAY\")", @@ -484,7 +484,7 @@ class FileBasedDataSourceSuite extends QueryTest spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`a`", "columnType" -> "UDT(\"ARRAY\")", @@ -512,7 +512,7 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[AnalysisException] { sql("select interval 1 
days").write.format(format).mode("overwrite").save(tempDir) }, - errorClass = "_LEGACY_ERROR_TEMP_1136", + condition = "_LEGACY_ERROR_TEMP_1136", parameters = Map.empty ) } @@ -529,7 +529,7 @@ class FileBasedDataSourceSuite extends QueryTest spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`a`", "columnType" -> "\"INTERVAL\"", @@ -542,7 +542,7 @@ class FileBasedDataSourceSuite extends QueryTest spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`a`", "columnType" -> "UDT(\"INTERVAL\")", @@ -579,7 +579,7 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[AnalysisException] { sql("select null").write.format(format).mode("overwrite").save(tempDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`NULL`", "columnType" -> "\"VOID\"", @@ -592,7 +592,7 @@ class FileBasedDataSourceSuite extends QueryTest spark.udf.register("testType", () => new NullData()) sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`testType()`", "columnType" -> "UDT(\"VOID\")", @@ -607,7 +607,7 @@ class FileBasedDataSourceSuite extends QueryTest spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`a`", "columnType" -> "\"VOID\"", @@ -621,7 +621,7 @@ class FileBasedDataSourceSuite extends QueryTest spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`a`", "columnType" -> "UDT(\"VOID\")", @@ -657,14 +657,14 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[SparkException] { sql(s"select b from $tableName").collect() }.getCause.asInstanceOf[SparkRuntimeException], - errorClass = "_LEGACY_ERROR_TEMP_2093", + condition = "_LEGACY_ERROR_TEMP_2093", parameters = Map("requiredFieldName" -> "b", "matchedOrcFields" -> "[b, B]") ) checkError( exception = intercept[SparkException] { sql(s"select B from $tableName").collect() }.getCause.asInstanceOf[SparkRuntimeException], - errorClass = "_LEGACY_ERROR_TEMP_2093", + condition = "_LEGACY_ERROR_TEMP_2093", parameters = Map("requiredFieldName" -> "b", "matchedOrcFields" -> "[b, B]") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala b/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala index 48a16f01d5749..6cd8ade41da14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala @@ -225,7 +225,7 @@ class TPCDSTables(spark: SparkSession, dsdgenDir: 
String, scaleFactor: Int) // datagen speed files will be truncated to maxRecordsPerFile value, so the final // result will be the same. val numRows = data.count() - val maxRecordPerFile = spark.conf.get(SQLConf.MAX_RECORDS_PER_FILE) + val maxRecordPerFile = spark.sessionState.conf.getConf(SQLConf.MAX_RECORDS_PER_FILE) if (maxRecordPerFile > 0 && numRows > maxRecordPerFile) { val numFiles = (numRows.toDouble/maxRecordPerFile).ceil.toInt diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 97a56bdea7be7..b9491a79cc3a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -59,7 +59,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("stack(1.1, 1, 2, 3)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"stack(1.1, 1, 2, 3)\"", "paramIndex" -> "first", @@ -77,7 +77,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("stack(-1, 1, 2, 3)") }, - errorClass = "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE", + condition = "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE", parameters = Map( "sqlExpr" -> "\"stack(-1, 1, 2, 3)\"", "exprName" -> "`n`", @@ -95,7 +95,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("stack(2, 1, '2.2')") }, - errorClass = "DATATYPE_MISMATCH.STACK_COLUMN_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.STACK_COLUMN_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"stack(2, 1, 2.2)\"", "columnIndex" -> "0", @@ -118,7 +118,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df2.selectExpr("stack(n, a, b, c)") }, - errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( "sqlExpr" -> "\"stack(n, a, b, c)\"", "inputName" -> "`n`", @@ -136,7 +136,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df3.selectExpr("stack(2, a, b)") }, - errorClass = "DATATYPE_MISMATCH.STACK_COLUMN_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.STACK_COLUMN_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"stack(2, a, b)\"", "columnIndex" -> "0", @@ -287,7 +287,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { spark.range(2).select(inline(array())) }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"inline(array())\"", "paramIndex" -> "first", @@ -330,7 +330,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(inline(array(struct(Symbol("a")), struct(Symbol("b"))))) }, - errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array(struct(a), struct(b))\"", "functionName" -> "`array`", @@ -348,7 +348,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { 
df.select(inline(array(struct(Symbol("a")), struct(lit(2))))) }, - errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", parameters = Map( "sqlExpr" -> "\"array(struct(a), struct(2))\"", "functionName" -> "`array`", @@ -427,7 +427,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("select 1 + explode(array(min(c2), max(c2))) from t1 group by c1") }, - errorClass = "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS", + condition = "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS", parameters = Map( "expression" -> "\"(1 + explode(array(min(c2), max(c2))))\"")) @@ -440,7 +440,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { | posexplode(array(min(c2), max(c2))) |from t1 group by c1""".stripMargin) }, - errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", + condition = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", parameters = Map( "num" -> "2", "generators" -> ("\"explode(array(min(c2), max(c2)))\", " + @@ -453,7 +453,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("SELECT array(array(1, 2), array(3)) v").select(explode(explode($"v"))).collect() }, - errorClass = "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS", + condition = "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS", parameters = Map("expression" -> "\"explode(explode(v))\"")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index fcb937d82ba42..0f5582def82da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -597,10 +597,10 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan SQLConf.CROSS_JOINS_ENABLED.key -> "true") { assert(statisticSizeInByte(spark.table("testData2")) > - spark.conf.get[Long](SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + sqlConf.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) assert(statisticSizeInByte(spark.table("testData")) < - spark.conf.get[Long](SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + sqlConf.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 6b4be982b3ecb..7b19ad988d308 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -190,7 +190,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { nonStringDF.select(json_tuple($"a", "1")).collect() }, - errorClass = "DATATYPE_MISMATCH.NON_STRING_TYPE", + condition = "DATATYPE_MISMATCH.NON_STRING_TYPE", parameters = Map( "sqlExpr" -> "\"json_tuple(a, 1)\"", "funcName" -> "`json_tuple`" @@ -499,7 +499,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df2.selectExpr("to_json(a, named_struct('a', 1))") }, - errorClass = "INVALID_OPTIONS.NON_MAP_FUNCTION", + condition = "INVALID_OPTIONS.NON_MAP_FUNCTION", parameters = Map.empty, context = ExpectedContext( fragment = "to_json(a, named_struct('a', 1))", @@ -512,7 +512,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = 
intercept[AnalysisException] { df2.selectExpr("to_json(a, map('a', 1))") }, - errorClass = "INVALID_OPTIONS.NON_STRING_TYPE", + condition = "INVALID_OPTIONS.NON_STRING_TYPE", parameters = Map("mapType" -> "\"MAP\""), context = ExpectedContext( fragment = "to_json(a, map('a', 1))", @@ -543,7 +543,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df3.selectExpr("from_json(value, 1)") }, - errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", + condition = "INVALID_SCHEMA.NON_STRING_LITERAL", parameters = Map("inputSchema" -> "\"1\""), context = ExpectedContext( fragment = "from_json(value, 1)", @@ -556,7 +556,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df3.selectExpr("""from_json(value, 'time InvalidType')""") }, - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map( "error" -> "'InvalidType'", @@ -572,7 +572,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df3.selectExpr("from_json(value, 'time Timestamp', named_struct('a', 1))") }, - errorClass = "INVALID_OPTIONS.NON_MAP_FUNCTION", + condition = "INVALID_OPTIONS.NON_MAP_FUNCTION", parameters = Map.empty, context = ExpectedContext( fragment = "from_json(value, 'time Timestamp', named_struct('a', 1))", @@ -584,7 +584,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df3.selectExpr("from_json(value, 'time Timestamp', map('a', 1))") }, - errorClass = "INVALID_OPTIONS.NON_STRING_TYPE", + condition = "INVALID_OPTIONS.NON_STRING_TYPE", parameters = Map("mapType" -> "\"MAP\""), context = ExpectedContext( fragment = "from_json(value, 'time Timestamp', map('a', 1))", @@ -657,7 +657,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq("""{{"f": 1}: "a"}""").toDS().select(from_json($"value", schema)) }, - errorClass = "DATATYPE_MISMATCH.INVALID_JSON_MAP_KEY_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_JSON_MAP_KEY_TYPE", parameters = Map( "schema" -> "\"MAP, STRING>\"", "sqlExpr" -> "\"entries\""), @@ -851,7 +851,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkException] { df.select(from_json($"value", schema, Map("mode" -> "FAILFAST"))).collect() }, - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map( "badRecord" -> "[null,null,{\"a\" 1, \"b\": 11}]", "failFastMode" -> "FAILFAST") @@ -861,7 +861,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(from_json($"value", schema, Map("mode" -> "DROPMALFORMED"))).collect() }, - errorClass = "_LEGACY_ERROR_TEMP_1099", + condition = "_LEGACY_ERROR_TEMP_1099", parameters = Map( "funcName" -> "from_json", "mode" -> "DROPMALFORMED", @@ -889,14 +889,14 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = ex, - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map( "badRecord" -> "[null,11,{\"a\": \"1\", \"b\": 11}]", "failFastMode" -> "FAILFAST") ) checkError( exception = ex.getCause.asInstanceOf[SparkRuntimeException], - errorClass = "CANNOT_PARSE_JSON_FIELD", + condition = 
"CANNOT_PARSE_JSON_FIELD", parameters = Map( "fieldName" -> toSQLValue("a", StringType), "fieldValue" -> "1", @@ -973,7 +973,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { Seq(("""{"i":1}""", "i int")).toDF("json", "schema") .select(from_json($"json", $"schema", options)).collect() }, - errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", + condition = "INVALID_SCHEMA.NON_STRING_LITERAL", parameters = Map("inputSchema" -> "\"schema\""), context = ExpectedContext(fragment = "from_json", getCurrentClassCallSitePattern) ) @@ -1208,7 +1208,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(from_json($"json", invalidJsonSchema, Map.empty[String, String])).collect() }, - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'{'", "hint" -> ""), ExpectedContext("from_json", getCurrentClassCallSitePattern) ) @@ -1218,7 +1218,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(from_json($"json", invalidDataType, Map.empty[String, String])).collect() }, - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map("typeName" -> "\"COW\""), ExpectedContext("from_json", getCurrentClassCallSitePattern) ) @@ -1228,7 +1228,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(from_json($"json", invalidTableSchema, Map.empty[String, String])).collect() }, - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'INT'", "hint" -> ""), ExpectedContext("from_json", getCurrentClassCallSitePattern) ) @@ -1247,7 +1247,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkException] { df.select(from_json($"value", schema, Map("mode" -> "FAILFAST"))("b")).collect() }, - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map( "badRecord" -> "[null,null]", "failFastMode" -> "FAILFAST") @@ -1257,7 +1257,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkException] { df.select(from_json($"value", schema, Map("mode" -> "FAILFAST"))("a")).collect() }, - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map( "badRecord" -> "[null,null]", "failFastMode" -> "FAILFAST") @@ -1279,7 +1279,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkException] { df.select(from_json($"value", schema, Map("mode" -> "FAILFAST"))("b")).collect() }, - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map( "badRecord" -> "[null]", "failFastMode" -> "FAILFAST") @@ -1289,7 +1289,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkException] { df.select(from_json($"value", schema, Map("mode" -> "FAILFAST"))("a")).collect() }, - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map( "badRecord" -> "[null]", "failFastMode" -> "FAILFAST") @@ -1401,7 +1401,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = 
intercept[AnalysisException] { df.select($"a").withColumn("c", to_json(structData)).collect() }, - errorClass = "DATATYPE_MISMATCH.CANNOT_CONVERT_TO_JSON", + condition = "DATATYPE_MISMATCH.CANNOT_CONVERT_TO_JSON", parameters = Map( "sqlExpr" -> "\"to_json(NAMED_STRUCT('b', 1))\"", "name" -> "`b`", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 336cf12ae57c5..9afba65183974 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -184,7 +184,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { query: String, parameters: Map[String, String]): Unit = { checkError( exception = intercept[AnalysisException] {sql(query)}, - errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", + condition = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", sqlState = "42702", parameters = parameters ) @@ -194,7 +194,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { query: String, lca: String, windowExprRegex: String): Unit = { checkErrorMatchPVals( exception = intercept[AnalysisException] {sql(query)}, - errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_WINDOW", + condition = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_WINDOW", parameters = Map("lca" -> lca, "windowExpr" -> windowExprRegex) ) } @@ -209,11 +209,14 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { } private def checkSameError( - q1: String, q2: String, errorClass: String, errorParams: Map[String, String]): Unit = { + q1: String, + q2: String, + condition: String, + errorParams: Map[String, String]): Unit = { val e1 = intercept[AnalysisException] { sql(q1) } val e2 = intercept[AnalysisException] { sql(q2) } - assert(e1.getErrorClass == errorClass) - assert(e2.getErrorClass == errorClass) + assert(e1.getErrorClass == condition) + assert(e2.getErrorClass == condition) errorParams.foreach { case (k, v) => assert(e1.messageParameters.get(k).exists(_ == v)) assert(e2.messageParameters.get(k).exists(_ == v)) @@ -258,7 +261,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { exception = intercept[AnalysisException] { sql(s"SELECT 10000 AS lca, count(lca) FROM $testTable GROUP BY dept") }, - errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + condition = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", sqlState = "0A000", parameters = Map( "lca" -> "`lca`", @@ -269,7 +272,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { exception = intercept[AnalysisException] { sql(s"SELECT dept AS lca, avg(lca) FROM $testTable GROUP BY dept") }, - errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + condition = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", sqlState = "0A000", parameters = Map( "lca" -> "`lca`", @@ -281,7 +284,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { exception = intercept[AnalysisException] { sql(s"SELECT sum(salary) AS a, avg(a) FROM $testTable") }, - errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + condition = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", sqlState = "0A000", parameters = Map( "lca" -> "`a`", @@ -518,7 +521,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { exception = intercept[AnalysisException] { sql(query2) }, - errorClass = 
"UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", sqlState = "42703", parameters = Map("objectName" -> s"`id1`"), context = ExpectedContext( @@ -796,7 +799,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { exception = intercept[AnalysisException] { sql(s"SELECT dept AS d, d AS new_dept, new_dep + 1 AS newer_dept FROM $testTable") }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map("objectName" -> s"`new_dep`", "proposal" -> "`dept`, `name`, `bonus`, `salary`, `properties`"), @@ -809,7 +812,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { exception = intercept[AnalysisException] { sql(s"SELECT count(name) AS cnt, cnt + 1, count(unresovled) FROM $testTable GROUP BY dept") }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map("objectName" -> s"`unresovled`", "proposal" -> "`name`, `bonus`, `dept`, `properties`, `salary`"), @@ -823,7 +826,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { sql(s"SELECT * FROM range(1, 7) WHERE (" + s"SELECT id2 FROM (SELECT 1 AS id, other_id + 1 AS id2)) > 5") }, - errorClass = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", sqlState = "42703", parameters = Map("objectName" -> s"`other_id`"), context = ExpectedContext( @@ -898,7 +901,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { exception = intercept[AnalysisException] { sql( "SELECT dept AS a, dept, " + s"(SELECT count(col) FROM VALUES (1), (2) AS data(col) WHERE col = dept) $groupBySeg") }, - errorClass = "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION", + condition = "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION", parameters = Map("sqlExpr" -> "\"scalarsubquery(dept)\""), context = ExpectedContext( fragment = "(SELECT count(col) FROM VALUES (1), (2) AS data(col) WHERE col = dept)", @@ -910,7 +913,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { "SELECT dept AS a, a, " + s"(SELECT count(col) FROM VALUES (1), (2) AS data(col) WHERE col = dept) $groupBySeg" ) }, - errorClass = "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION", + condition = "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION", parameters = Map("sqlExpr" -> "\"scalarsubquery(dept)\""), context = ExpectedContext( fragment = "(SELECT count(col) FROM VALUES (1), (2) AS data(col) WHERE col = dept)", @@ -924,7 +927,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { exception = intercept[AnalysisException] { sql(s"SELECT avg(salary) AS a, avg(a) $windowExpr $groupBySeg") }, - errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + condition = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", sqlState = "0A000", parameters = Map("lca" -> "`a`", "aggFunc" -> "\"avg(lateralAliasReference(a))\"") ) @@ -1009,7 +1012,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { "(partition by dept order by salary rows between n preceding and current row) as rank " + s"from $testTable where dept in (1, 6)") }, - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> "Frame bound value must be a literal."), context = ExpectedContext(fragment = "n preceding", start = 87, stop = 97) ) @@ -1188,7 +1191,7 @@ class LateralColumnAliasSuite 
extends LateralColumnAliasSuiteBase { s"from $testTable", s"select dept as d, d, rank() over (partition by dept order by avg(salary)) " + s"from $testTable", - errorClass = "MISSING_GROUP_BY", + condition = "MISSING_GROUP_BY", errorParams = Map.empty ) checkSameError( @@ -1196,7 +1199,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { s"from $testTable", "select salary as s, s, sum(sum(salary)) over (partition by dept order by salary) " + s"from $testTable", - errorClass = "MISSING_GROUP_BY", + condition = "MISSING_GROUP_BY", errorParams = Map.empty ) @@ -1338,7 +1341,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { |""".stripMargin ) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`Freq`", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala index cc0cce08162ae..ebca6b26fce95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala @@ -56,7 +56,7 @@ abstract class MetadataCacheSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkException] { df.count() }, - errorClass = "FAILED_READ_FILE.FILE_NOT_EXIST", + condition = "FAILED_READ_FILE.FILE_NOT_EXIST", parameters = Map("path" -> ".*") ) } @@ -87,7 +87,7 @@ class MetadataCacheV1Suite extends MetadataCacheSuite { exception = intercept[SparkException] { sql("select count(*) from view_refresh").first() }, - errorClass = "FAILED_READ_FILE.FILE_NOT_EXIST", + condition = "FAILED_READ_FILE.FILE_NOT_EXIST", parameters = Map("path" -> ".*") ) @@ -115,7 +115,7 @@ class MetadataCacheV1Suite extends MetadataCacheSuite { exception = intercept[SparkException] { sql("select count(*) from view_refresh").first() }, - errorClass = "FAILED_READ_FILE.FILE_NOT_EXIST", + condition = "FAILED_READ_FILE.FILE_NOT_EXIST", parameters = Map("path" -> ".*") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala index b9daece4913f2..b95b7b9d4c00d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala @@ -61,7 +61,7 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(sql(s"select $func"), Row(user)) checkError( exception = intercept[ParseException](sql(s"select $func()")), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> s"'$func'", "hint" -> "")) } } @@ -238,7 +238,7 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("reflect(cast(null as string), 'fromString', a)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL", + condition = "DATATYPE_MISMATCH.UNEXPECTED_NULL", parameters = Map( "exprName" -> "`class`", "sqlExpr" -> "\"reflect(CAST(NULL AS STRING), fromString, a)\""), @@ -247,7 +247,7 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("reflect('java.util.UUID', cast(null as string), a)") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL", + condition = "DATATYPE_MISMATCH.UNEXPECTED_NULL", parameters = Map( "exprName" -> "`method`", "sqlExpr" 
-> "\"reflect(java.util.UUID, CAST(NULL AS STRING), a)\""), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/NestedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/NestedDataSourceSuite.scala index f83e7b6727b16..f570fc3ab25f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/NestedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/NestedDataSourceSuite.scala @@ -65,7 +65,7 @@ trait NestedDataSourceSuiteBase extends QueryTest with SharedSparkSession { .load(path) .collect() }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`camelcase`") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala index be3669cc62023..c90b34d45e783 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala @@ -73,7 +73,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { exception = intercept[AnalysisException] { spark.sql("select :P", Map("p" -> 1)) }, - errorClass = "UNBOUND_SQL_PARAMETER", + condition = "UNBOUND_SQL_PARAMETER", parameters = Map("name" -> "P"), context = ExpectedContext( fragment = ":P", @@ -246,8 +246,8 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { exception = intercept[ParseException] { spark.sql(sqlText, args) }, - errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", - parameters = Map("statement" -> "CREATE VIEW body"), + condition = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", + parameters = Map("statement" -> "the query of CREATE VIEW"), context = ExpectedContext( fragment = sqlText, start = 0, @@ -261,8 +261,8 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { exception = intercept[ParseException] { spark.sql(sqlText, args) }, - errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", - parameters = Map("statement" -> "CREATE VIEW body"), + condition = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", + parameters = Map("statement" -> "the query of CREATE VIEW"), context = ExpectedContext( fragment = sqlText, start = 0, @@ -276,8 +276,8 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { exception = intercept[ParseException] { spark.sql(sqlText, args) }, - errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", - parameters = Map("statement" -> "CREATE VIEW body"), + condition = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", + parameters = Map("statement" -> "the query of CREATE VIEW"), context = ExpectedContext( fragment = sqlText, start = 0, @@ -291,8 +291,8 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { exception = intercept[ParseException] { spark.sql(sqlText, args) }, - errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", - parameters = Map("statement" -> "CREATE VIEW body"), + condition = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", + parameters = Map("statement" -> "the query of CREATE VIEW"), context = ExpectedContext( fragment = sqlText, start = 0, @@ -310,8 +310,8 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { exception = intercept[ParseException] { spark.sql(sqlText, args) }, - errorClass = 
"UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", - parameters = Map("statement" -> "CREATE VIEW body"), + condition = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", + parameters = Map("statement" -> "the query of CREATE VIEW"), context = ExpectedContext( fragment = sqlText, start = 0, @@ -329,8 +329,8 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { exception = intercept[ParseException] { spark.sql(sqlText, args) }, - errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", - parameters = Map("statement" -> "CREATE VIEW body"), + condition = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", + parameters = Map("statement" -> "the query of CREATE VIEW"), context = ExpectedContext( fragment = sqlText, start = 0, @@ -342,7 +342,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { exception = intercept[AnalysisException] { spark.sql("select :abc, :def", Map("abc" -> 1)) }, - errorClass = "UNBOUND_SQL_PARAMETER", + condition = "UNBOUND_SQL_PARAMETER", parameters = Map("name" -> "def"), context = ExpectedContext( fragment = ":def", @@ -352,7 +352,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { exception = intercept[AnalysisException] { sql("select :abc").collect() }, - errorClass = "UNBOUND_SQL_PARAMETER", + condition = "UNBOUND_SQL_PARAMETER", parameters = Map("name" -> "abc"), context = ExpectedContext( fragment = ":abc", @@ -365,7 +365,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { exception = intercept[AnalysisException] { spark.sql("select ?, ?", Array(1)) }, - errorClass = "UNBOUND_SQL_PARAMETER", + condition = "UNBOUND_SQL_PARAMETER", parameters = Map("name" -> "_10"), context = ExpectedContext( fragment = "?", @@ -375,7 +375,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { exception = intercept[AnalysisException] { sql("select ?").collect() }, - errorClass = "UNBOUND_SQL_PARAMETER", + condition = "UNBOUND_SQL_PARAMETER", parameters = Map("name" -> "_7"), context = ExpectedContext( fragment = "?", @@ -472,7 +472,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { exception = intercept[AnalysisException] { spark.sql("select :param1, ?", Map("param1" -> 1)) }, - errorClass = "UNBOUND_SQL_PARAMETER", + condition = "UNBOUND_SQL_PARAMETER", parameters = Map("name" -> "_16"), context = ExpectedContext( fragment = "?", @@ -483,7 +483,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { exception = intercept[AnalysisException] { spark.sql("select :param1, ?", Array(1)) }, - errorClass = "UNBOUND_SQL_PARAMETER", + condition = "UNBOUND_SQL_PARAMETER", parameters = Map("name" -> "param1"), context = ExpectedContext( fragment = ":param1", @@ -498,7 +498,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { "CREATE TABLE t11(c1 int default :parm) USING parquet", args = Map("parm" -> 5)) }, - errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", + condition = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", parameters = Map("statement" -> "DEFAULT"), context = ExpectedContext( fragment = "default :parm", @@ -602,7 +602,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { lit(Array("a")), array(str_to_map(lit("a:1,b:2,c:3")))))) }, - errorClass = "INVALID_SQL_ARG", + condition = "INVALID_SQL_ARG", 
parameters = Map("name" -> "m"), context = ExpectedContext( fragment = "map_from_arrays", @@ -715,4 +715,30 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { spark.sessionState.analyzer.executeAndCheck(analyzedPlan, df.queryExecution.tracker) checkAnswer(df, Row(11)) } + + test("SPARK-49398: Cache Table with parameter markers in select query should throw " + + "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT") { + val sqlText = "CACHE TABLE CacheTable as SELECT 1 + :param1" + checkError( + exception = intercept[AnalysisException] { + spark.sql(sqlText, Map("param1" -> "1")).show() + }, + condition = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", + parameters = Map("statement" -> "the query of CACHE TABLE"), + context = ExpectedContext( + fragment = sqlText, + start = 0, + stop = sqlText.length - 1) + ) + } + + test("SPARK-49398: Cache Table with parameter in identifier should work") { + val cacheName = "MyCacheTable" + withCache(cacheName) { + spark.sql("CACHE TABLE IDENTIFIER(:param) as SELECT 1 as c1", Map("param" -> cacheName)) + checkAnswer( + spark.sql("SHOW COLUMNS FROM IDENTIFIER(?)", args = Array(cacheName)), + Row("c1")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala index e3ebbadbb829a..cb9d0909554b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ResolveDefaultColumnsSuite.scala @@ -40,7 +40,7 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t values (timestamp'2020-12-31')") }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "tableColumns" -> "`c1`, `c2`", @@ -68,7 +68,7 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t values (timestamp'2020-12-31')") }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "tableColumns" -> "`c1`, `c2`", @@ -85,7 +85,7 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t values (1, 2, 3)") }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "tableColumns" -> "`c1`, `c2`, `c3`, `c4`", @@ -102,7 +102,7 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t partition(c3=3, c4=4) values (1)") }, - errorClass = "INSERT_PARTITION_COLUMN_ARITY_MISMATCH", + condition = "INSERT_PARTITION_COLUMN_ARITY_MISMATCH", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "tableColumns" -> "`c1`, `c2`, `c3`, `c4`", @@ -120,7 +120,7 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t partition(c3=3, c4) values (1, 2)") }, - errorClass = 
"INSERT_PARTITION_COLUMN_ARITY_MISMATCH", + condition = "INSERT_PARTITION_COLUMN_ARITY_MISMATCH", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "tableColumns" -> "`c1`, `c2`, `c3`, `c4`", @@ -173,7 +173,7 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("create table demos.test_ts_other (a int default 'abc') using parquet") }, - errorClass = "INVALID_DEFAULT_VALUE.DATA_TYPE", + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`a`", @@ -184,7 +184,7 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("create table demos.test_ts_other (a timestamp default 'invalid') using parquet") }, - errorClass = "INVALID_DEFAULT_VALUE.DATA_TYPE", + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`a`", @@ -195,7 +195,7 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("create table demos.test_ts_other (a boolean default 'true') using parquet") }, - errorClass = "INVALID_DEFAULT_VALUE.DATA_TYPE", + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`a`", @@ -206,7 +206,7 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("create table demos.test_ts_other (a int default true) using parquet") }, - errorClass = "INVALID_DEFAULT_VALUE.DATA_TYPE", + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`a`", @@ -237,7 +237,7 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[SparkRuntimeException]( sql(s"CREATE TABLE t(c $typeName(3) DEFAULT 'spark') USING parquet")), - errorClass = "EXCEED_LIMIT_LENGTH", + condition = "EXCEED_LIMIT_LENGTH", parameters = Map("limit" -> "3")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 4c4560e3fc48b..5de4170a1c112 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -119,7 +119,7 @@ class RowSuite extends SparkFunSuite with SharedSparkSession { exception = intercept[SparkUnsupportedOperationException] { rowWithoutSchema.fieldIndex("foo") }, - errorClass = "UNSUPPORTED_CALL.FIELD_INDEX", + condition = "UNSUPPORTED_CALL.FIELD_INDEX", parameters = Map("methodName" -> "fieldIndex", "className" -> "Row", "fieldName" -> "`foo`") ) } @@ -132,7 +132,7 @@ class RowSuite extends SparkFunSuite with SharedSparkSession { exception = intercept[SparkRuntimeException] { rowWithNullValue.getLong(position) }, - errorClass = "ROW_VALUE_IS_NULL", + condition = "ROW_VALUE_IS_NULL", parameters = Map("index" -> position.toString) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala index 4052130720811..352197f96acb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.internal.config 
+import org.apache.spark.sql.internal.RuntimeConfigImpl import org.apache.spark.sql.internal.SQLConf.CHECKPOINT_LOCATION import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE class RuntimeConfigSuite extends SparkFunSuite { - private def newConf(): RuntimeConfig = new RuntimeConfig + private def newConf(): RuntimeConfig = new RuntimeConfigImpl() test("set and get") { val conf = newConf() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala index 63ed26bdeddf1..170105200f1d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala @@ -276,7 +276,7 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils with AdaptiveSparkP checkError( exception = intercept[AnalysisException]( sql(s"INSERT INTO t1 (c1, c2, c2) values(1, 2, 3)")), - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`c2`")) } } @@ -288,7 +288,7 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils with AdaptiveSparkP checkError( exception = intercept[AnalysisException](sql(s"INSERT INTO t1 (c1, c2, c4) values(1, 2, 3)")), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map("objectName" -> "`c4`", "proposal" -> "`c1`, `c2`, `c3`"), context = ExpectedContext( @@ -307,7 +307,7 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils with AdaptiveSparkP sql(s"INSERT INTO t1 (c1, c2) values(1, 2, 3)") }, sqlState = None, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", parameters = Map( "tableName" -> ".*`t1`", "tableColumns" -> "`c1`, `c2`", @@ -319,7 +319,7 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils with AdaptiveSparkP sql(s"INSERT INTO t1 (c1, c2, c3) values(1, 2)") }, sqlState = None, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> ".*`t1`", "tableColumns" -> "`c1`, `c2`, `c3`", @@ -399,7 +399,7 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils with AdaptiveSparkP sql("INSERT OVERWRITE t PARTITION (c='2', C='3') VALUES (1)") }, sqlState = None, - errorClass = "DUPLICATE_KEY", + condition = "DUPLICATE_KEY", parameters = Map("keyColumn" -> "`c`"), context = ExpectedContext("PARTITION (c='2', C='3')", 19, 42) ) @@ -441,12 +441,11 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils with AdaptiveSparkP exception = intercept[SparkNumberFormatException] { sql("insert into t partition(a='ansi') values('ansi')") }, - errorClass = "CAST_INVALID_INPUT", + condition = "CAST_INVALID_INPUT", parameters = Map( "expression" -> "'ansi'", "sourceType" -> "\"STRING\"", - "targetType" -> "\"INT\"", - "ansiConfig" -> "\"spark.sql.ansi.enabled\"" + "targetType" -> "\"INT\"" ), context = ExpectedContext("insert into t partition(a='ansi')", 0, 32) ) @@ -492,7 +491,7 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils with AdaptiveSparkP exception = intercept[AnalysisException] { sql("alter table t drop partition(dt='8')") }, - errorClass = "PARTITIONS_NOT_FOUND", + condition = "PARTITIONS_NOT_FOUND", sqlState = None, parameters = Map( "partitionList" -> "PARTITION \\(`dt` = 8\\)", @@ -512,7 +511,7 
@@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils with AdaptiveSparkP exception = intercept[AnalysisException] { sql("alter table t drop partition(dt='08')") }, - errorClass = "PARTITIONS_NOT_FOUND", + condition = "PARTITIONS_NOT_FOUND", sqlState = None, parameters = Map( "partitionList" -> "PARTITION \\(`dt` = 08\\)", @@ -562,7 +561,7 @@ class FileSourceSQLInsertTestSuite extends SQLInsertTestSuite with SharedSparkSe v2ErrorClass: String, v1Parameters: Map[String, String], v2Parameters: Map[String, String]): Unit = { - checkError(exception = exception, sqlState = None, errorClass = v1ErrorClass, + checkError(exception = exception, sqlState = None, condition = v1ErrorClass, parameters = v1Parameters) } @@ -582,7 +581,7 @@ class DSV2SQLInsertTestSuite extends SQLInsertTestSuite with SharedSparkSession v2ErrorClass: String, v1Parameters: Map[String, String], v2Parameters: Map[String, String]): Unit = { - checkError(exception = exception, sqlState = None, errorClass = v2ErrorClass, + checkError(exception = exception, sqlState = None, condition = v2ErrorClass, parameters = v2Parameters) } protected override def sparkConf: SparkConf = { @@ -598,7 +597,7 @@ class DSV2SQLInsertTestSuite extends SQLInsertTestSuite with SharedSparkSession exception = intercept[AnalysisException] { sql("INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2')") }, - errorClass = "STATIC_PARTITION_COLUMN_IN_INSERT_COLUMN_LIST", + condition = "STATIC_PARTITION_COLUMN_IN_INSERT_COLUMN_LIST", parameters = Map("staticName" -> "c")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 55313c8ac2f86..8176d02dbd02d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -100,7 +100,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark val sqlText = "describe functioN abcadf" checkError( exception = intercept[AnalysisException](sql(sqlText)), - errorClass = "UNRESOLVED_ROUTINE", + condition = "UNRESOLVED_ROUTINE", parameters = Map( "routineName" -> "`abcadf`", "searchPath" -> "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]"), @@ -111,10 +111,13 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } test("SPARK-34678: describe functions for table-valued functions") { + sql("describe function range").show(false) checkKeywordsExist(sql("describe function range"), "Function: range", "Class: org.apache.spark.sql.catalyst.plans.logical.Range", - "range(end: long)" + "range(start[, end[, step[, numSlices]]])", + "range(end)", + "Returns a table of values within a specified range." 
) } @@ -1659,7 +1662,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark exception = intercept[AnalysisException] { sql("select * from json.invalid_file") }, - errorClass = "PATH_NOT_FOUND", + condition = "PATH_NOT_FOUND", parameters = Map("path" -> "file:/.*invalid_file"), matchPVals = true ) @@ -1668,7 +1671,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark exception = intercept[AnalysisException] { sql(s"select id from `org.apache.spark.sql.hive.orc`.`file_path`") }, - errorClass = "_LEGACY_ERROR_TEMP_1138" + condition = "_LEGACY_ERROR_TEMP_1138" ) e = intercept[AnalysisException] { @@ -1833,7 +1836,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark exception = intercept[AnalysisException]{ sql("SELECT abc.* FROM nestedStructTable") }, - errorClass = "CANNOT_RESOLVE_STAR_EXPAND", + condition = "CANNOT_RESOLVE_STAR_EXPAND", parameters = Map("targetString" -> "`abc`", "columns" -> "`record`"), context = ExpectedContext(fragment = "abc.*", start = 7, stop = 11)) } @@ -1868,7 +1871,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark exception = intercept[AnalysisException]{ sql("select a.* from testData2") }, - errorClass = "_LEGACY_ERROR_TEMP_1050", + condition = "_LEGACY_ERROR_TEMP_1050", sqlState = None, parameters = Map("attributes" -> "(ArrayBuffer|List)\\(a\\)"), matchPVals = true, @@ -1922,7 +1925,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark exception = intercept[AnalysisException] { sql("SELECT a.* FROM temp_table_no_cols a") }, - errorClass = "CANNOT_RESOLVE_STAR_EXPAND", + condition = "CANNOT_RESOLVE_STAR_EXPAND", parameters = Map("targetString" -> "`a`", "columns" -> ""), context = ExpectedContext(fragment = "a.*", start = 7, stop = 9)) @@ -1930,7 +1933,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark exception = intercept[AnalysisException] { dfNoCols.select($"b.*") }, - errorClass = "CANNOT_RESOLVE_STAR_EXPAND", + condition = "CANNOT_RESOLVE_STAR_EXPAND", parameters = Map("targetString" -> "`b`", "columns" -> ""), context = ExpectedContext( fragment = "$", @@ -2568,20 +2571,21 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath) val newSession = spark.newSession() + val newSqlConf = newSession.sessionState.conf val originalValue = newSession.sessionState.conf.runSQLonFile try { - newSession.conf.set(SQLConf.RUN_SQL_ON_FILES, false) + newSqlConf.setConf(SQLConf.RUN_SQL_ON_FILES, false) intercept[AnalysisException] { newSession.sql(s"SELECT i, j FROM parquet.`${path.getCanonicalPath}`") } - newSession.conf.set(SQLConf.RUN_SQL_ON_FILES, true) + newSqlConf.setConf(SQLConf.RUN_SQL_ON_FILES, true) checkAnswer( newSession.sql(s"SELECT i, j FROM parquet.`${path.getCanonicalPath}`"), Row(1, "a")) } finally { - newSession.conf.set(SQLConf.RUN_SQL_ON_FILES, originalValue) + newSqlConf.setConf(SQLConf.RUN_SQL_ON_FILES, originalValue) } } } @@ -2677,7 +2681,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark exception = intercept[AnalysisException] { sql("SELECT nvl(1, 2, 3)") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> toSQLId("nvl"), "expectedNum" -> "2", @@ -2738,7 +2742,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with 
AdaptiveSpark exception = intercept[AnalysisException] { sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") }, - errorClass = "INCOMPATIBLE_COLUMN_TYPE", + condition = "INCOMPATIBLE_COLUMN_TYPE", parameters = Map( "tableOrdinalNumber" -> "second", "columnOrdinalNumber" -> "first", @@ -2761,7 +2765,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark exception = intercept[AnalysisException] { sql(query) }, - errorClass = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES", + condition = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES", sqlState = None, parameters = Map( "sqlExpr" -> "\"(c = C)\"", @@ -3073,7 +3077,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(sql("select s.I from t group by s.i"), Nil) } }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`I`", "fields" -> "`i`"), context = ExpectedContext( fragment = "s.I", @@ -3784,7 +3788,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark exception = intercept[AnalysisException] { sql("SELECT s LIKE 'm%@ca' ESCAPE '%' FROM df").collect() }, - errorClass = "INVALID_FORMAT.ESC_IN_THE_MIDDLE", + condition = "INVALID_FORMAT.ESC_IN_THE_MIDDLE", parameters = Map( "format" -> toSQLValue("m%@ca", StringType), "char" -> toSQLValue("@", StringType))) @@ -3801,7 +3805,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark exception = intercept[AnalysisException] { sql("SELECT a LIKE 'jialiuping%' ESCAPE '%' FROM df").collect() }, - errorClass = "INVALID_FORMAT.ESC_AT_THE_END", + condition = "INVALID_FORMAT.ESC_AT_THE_END", parameters = Map("format" -> toSQLValue("jialiuping%", StringType))) } } @@ -3901,7 +3905,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark exception = intercept[AnalysisException] { sql(s"CREATE TEMPORARY FUNCTION $functionName AS '$sumFuncClass'") }, - errorClass = "CANNOT_LOAD_FUNCTION_CLASS", + condition = "CANNOT_LOAD_FUNCTION_CLASS", parameters = Map( "className" -> "org.apache.spark.examples.sql.Spark33084", "functionName" -> "`test_udf`" @@ -3996,7 +4000,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } checkError( exception = e, - errorClass = "INVALID_TEMP_OBJ_REFERENCE", + condition = "INVALID_TEMP_OBJ_REFERENCE", parameters = Map( "obj" -> "VIEW", "objName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$testViewName`", @@ -4015,7 +4019,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } checkError( exception = e2, - errorClass = "INVALID_TEMP_OBJ_REFERENCE", + condition = "INVALID_TEMP_OBJ_REFERENCE", parameters = Map( "obj" -> "VIEW", "objName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$testViewName`", @@ -4901,7 +4905,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark exception = intercept[AnalysisException]( sql(sqlText) ), - errorClass = "MISSING_WINDOW_SPECIFICATION", + condition = "MISSING_WINDOW_SPECIFICATION", parameters = Map( "windowName" -> "unspecified_window", "docroot" -> SPARK_DOC_ROOT @@ -4909,6 +4913,21 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark ) } } + + test("SPARK-49659: Unsupported scalar subqueries in VALUES") { + checkError( + exception = intercept[AnalysisException]( + sql("SELECT * FROM VALUES ((SELECT 1) + (SELECT 2))") + ), + condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.SCALAR_SUBQUERY_IN_VALUES", + parameters = Map(), + 
context = ExpectedContext( + fragment = "VALUES ((SELECT 1) + (SELECT 2))", + start = 14, + stop = 45 + ) + ) + } } case class Foo(bar: Option[String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 01b9fdec9be3d..16118526f2fe4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -163,7 +163,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSparkSession exception = intercept[SparkUnsupportedOperationException] { Seq(InvalidInJava(1)).toDS() }, - errorClass = "_LEGACY_ERROR_TEMP_2140", + condition = "_LEGACY_ERROR_TEMP_2140", parameters = Map( "fieldName" -> "abstract", "walkedTypePath" -> "- root class: \"org.apache.spark.sql.InvalidInJava\"")) @@ -174,7 +174,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSparkSession exception = intercept[SparkUnsupportedOperationException] { Seq(InvalidInJava2(1)).toDS() }, - errorClass = "_LEGACY_ERROR_TEMP_2140", + condition = "_LEGACY_ERROR_TEMP_2140", parameters = Map( "fieldName" -> "0", "walkedTypePath" -> diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SetCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SetCommandSuite.scala index a8b359f308a2b..f4ea87b39c39b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SetCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SetCommandSuite.scala @@ -139,7 +139,7 @@ class SetCommandSuite extends QueryTest with SharedSparkSession with ResetSystem withSQLConf(key1 -> value1) { checkError( intercept[ParseException](sql("SET ${test.password}")), - errorClass = "INVALID_SET_SYNTAX" + condition = "INVALID_SET_SYNTAX" ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index 4ac05373e5a34..d3117ec411feb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -201,10 +201,10 @@ class SparkSessionBuilderSuite extends SparkFunSuite with Eventually { .getOrCreate() assert(session.conf.get("spark.app.name") === "test-app-SPARK-31234") - assert(session.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31234") + assert(session.conf.get(GLOBAL_TEMP_DATABASE.key) === "globalTempDB-SPARK-31234") session.sql("RESET") assert(session.conf.get("spark.app.name") === "test-app-SPARK-31234") - assert(session.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31234") + assert(session.conf.get(GLOBAL_TEMP_DATABASE.key) === "globalTempDB-SPARK-31234") } test("SPARK-31354: SparkContext only register one SparkSession ApplicationEnd listener") { @@ -244,8 +244,8 @@ class SparkSessionBuilderSuite extends SparkFunSuite with Eventually { .builder() .config(GLOBAL_TEMP_DATABASE.key, "globalTempDB-SPARK-31532-1") .getOrCreate() - assert(session.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31532") - assert(session1.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31532") + assert(session.conf.get(GLOBAL_TEMP_DATABASE.key) === "globalTempDB-SPARK-31532") + assert(session1.conf.get(GLOBAL_TEMP_DATABASE.key) === "globalTempDB-SPARK-31532") // do not propagate static sql configs to the existing default session 
SparkSession.clearActiveSession() @@ -255,9 +255,9 @@ class SparkSessionBuilderSuite extends SparkFunSuite with Eventually { .config(GLOBAL_TEMP_DATABASE.key, value = "globalTempDB-SPARK-31532-2") .getOrCreate() - assert(!session.conf.get(WAREHOUSE_PATH).contains("SPARK-31532-db")) - assert(session.conf.get(WAREHOUSE_PATH) === session2.conf.get(WAREHOUSE_PATH)) - assert(session2.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31532") + assert(!session.conf.get(WAREHOUSE_PATH.key).contains("SPARK-31532-db")) + assert(session.conf.get(WAREHOUSE_PATH.key) === session2.conf.get(WAREHOUSE_PATH.key)) + assert(session2.conf.get(GLOBAL_TEMP_DATABASE.key) === "globalTempDB-SPARK-31532") } test("SPARK-31532: propagate static sql configs if no existing SparkSession") { @@ -275,8 +275,8 @@ class SparkSessionBuilderSuite extends SparkFunSuite with Eventually { .config(WAREHOUSE_PATH.key, "SPARK-31532-db-2") .getOrCreate() assert(session.conf.get("spark.app.name") === "test-app-SPARK-31532-2") - assert(session.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31532-2") - assert(session.conf.get(WAREHOUSE_PATH) contains "SPARK-31532-db-2") + assert(session.conf.get(GLOBAL_TEMP_DATABASE.key) === "globalTempDB-SPARK-31532-2") + assert(session.conf.get(WAREHOUSE_PATH.key) contains "SPARK-31532-db-2") } test("SPARK-32062: reset listenerRegistered in SparkSession") { @@ -461,7 +461,7 @@ class SparkSessionBuilderSuite extends SparkFunSuite with Eventually { val expected = path.getFileSystem(hadoopConf).makeQualified(path).toString // session related configs assert(hadoopConf.get("hive.metastore.warehouse.dir") === expected) - assert(session.conf.get(WAREHOUSE_PATH) === expected) + assert(session.conf.get(WAREHOUSE_PATH.key) === expected) assert(session.sessionState.conf.warehousePath === expected) // shared configs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 322210bf5b59f..ba87028a71477 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -178,7 +178,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt MyColumnarRule(MyNewQueryStageRule(), MyNewQueryStageRule())) } withSession(extensions) { session => - session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, true) + session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, true) assert(session.sessionState.adaptiveRulesHolder.queryStagePrepRules .contains(MyQueryStagePrepRule())) assert(session.sessionState.columnarRules.contains( @@ -221,7 +221,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())) } withSession(extensions) { session => - session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, true) + session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, true) session.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") assert(session.sessionState.columnarRules.contains( MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))) @@ -280,7 +280,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())) } withSession(extensions) { session => - session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, enableAQE) + session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, 
enableAQE)
       assert(session.sessionState.columnarRules.contains(
         MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
       import session.implicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
new file mode 100644
index 0000000000000..e9fd07ecf18b7
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.{ConcurrentHashMap, Semaphore, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.concurrent.{ExecutionContext, Future}
+import scala.jdk.CollectionConverters._
+
+import org.scalatest.concurrent.Eventually
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart}
+import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.tags.ExtendedSQLTest
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Test cases for the tagging and cancellation APIs provided by [[SparkSession]].
+ */
+@ExtendedSQLTest
+class SparkSessionJobTaggingAndCancellationSuite
+  extends SparkFunSuite
+  with Eventually
+  with LocalSparkContext {
+
+  override def afterEach(): Unit = {
+    try {
+      // This suite should not interfere with the other test suites.
+      SparkSession.getActiveSession.foreach(_.stop())
+      SparkSession.clearActiveSession()
+      SparkSession.getDefaultSession.foreach(_.stop())
+      SparkSession.clearDefaultSession()
+      resetSparkContext()
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  test("Tags are not inherited by new sessions") {
+    val session = SparkSession.builder().master("local").getOrCreate()
+
+    assert(session.getTags() == Set())
+    session.addTag("one")
+    assert(session.getTags() == Set("one"))
+
+    val newSession = session.newSession()
+    assert(newSession.getTags() == Set())
+  }
+
+  test("Tags are inherited by cloned sessions") {
+    val session = SparkSession.builder().master("local").getOrCreate()
+
+    assert(session.getTags() == Set())
+    session.addTag("one")
+    assert(session.getTags() == Set("one"))
+
+    val clonedSession = session.cloneSession()
+    assert(clonedSession.getTags() == Set("one"))
+    clonedSession.addTag("two")
+    assert(clonedSession.getTags() == Set("one", "two"))
+
+    // Tags are not propagated back to the original session
+    assert(session.getTags() == Set("one"))
+  }
+
+  test("Tags set from session are prefixed with session UUID") {
+    sc = new SparkContext("local[2]", "test")
+    val session = SparkSession.builder().sparkContext(sc).getOrCreate()
+    import session.implicits._
+
+    val sem = new Semaphore(0)
+    sc.addSparkListener(new SparkListener {
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        sem.release()
+      }
+    })
+
+    session.addTag("one")
+    Future {
+      session.range(1, 10000).map { i => Thread.sleep(100); i }.count()
+    }(ExecutionContext.global)
+
+    assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
+    val activeJobsFuture =
+      session.sparkContext.cancelJobsWithTagWithFuture(session.managedJobTags.get("one"), "reason")
+    val activeJob = ThreadUtils.awaitResult(activeJobsFuture, 60.seconds).head
+    val actualTags = activeJob.properties.getProperty(SparkContext.SPARK_JOB_TAGS)
+      .split(SparkContext.SPARK_JOB_TAGS_SEP)
+    assert(actualTags.toSet == Set(
+      session.sessionJobTag,
+      s"${session.sessionJobTag}-one",
+      SQLExecution.executionIdJobTag(
+        session,
+        activeJob.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)))
+  }
+
+  test("Cancellation APIs in SparkSession are isolated") {
+    sc = new SparkContext("local[2]", "test")
+    val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate()
+    var (sessionA, sessionB, sessionC): (SparkSession, SparkSession, SparkSession) =
+      (null, null, null)
+
+    // global ExecutionContext has only 2 threads in Apache Spark CI
+    // create own thread pool for the three Futures used in this test
+    val numThreads = 3
+    val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads)
+    val executionContext = ExecutionContext.fromExecutorService(fpool)
+
+    try {
+      // Add a listener to release the semaphore once jobs are launched.
+      val sem = new Semaphore(0)
+      val jobEnded = new AtomicInteger(0)
+      val jobProperties: ConcurrentHashMap[Int, java.util.Properties] = new ConcurrentHashMap()
+
+      sc.addSparkListener(new SparkListener {
+        override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+          jobProperties.put(jobStart.jobId, jobStart.properties)
+          sem.release()
+        }
+
+        override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+          sem.release()
+          jobEnded.incrementAndGet()
+        }
+      })
+
+      // Note: since tags are added in the Future threads, they don't need to be cleared in between.
+      val jobA = Future {
+        sessionA = globalSession.cloneSession()
+        import globalSession.implicits._
+
+        assert(sessionA.getTags() == Set())
+        sessionA.addTag("two")
+        assert(sessionA.getTags() == Set("two"))
+        sessionA.clearTags() // check that clearing all tags works
+        assert(sessionA.getTags() == Set())
+        sessionA.addTag("one")
+        assert(sessionA.getTags() == Set("one"))
+        try {
+          sessionA.range(1, 10000).map { i => Thread.sleep(100); i }.count()
+        } finally {
+          sessionA.clearTags() // clear for the case of thread reuse by another Future
+        }
+      }(executionContext)
+      val jobB = Future {
+        sessionB = globalSession.cloneSession()
+        import globalSession.implicits._
+
+        assert(sessionB.getTags() == Set())
+        sessionB.addTag("one")
+        sessionB.addTag("two")
+        sessionB.addTag("one")
+        sessionB.addTag("two") // duplicates shouldn't matter
+        assert(sessionB.getTags() == Set("one", "two"))
+        try {
+          sessionB.range(1, 10000, 2).map { i => Thread.sleep(100); i }.count()
+        } finally {
+          sessionB.clearTags() // clear for the case of thread reuse by another Future
+        }
+      }(executionContext)
+      val jobC = Future {
+        sessionC = globalSession.cloneSession()
+        import globalSession.implicits._
+
+        sessionC.addTag("foo")
+        sessionC.removeTag("foo")
+        assert(sessionC.getTags() == Set()) // check that removing the last tag works
+        sessionC.addTag("boo")
+        try {
+          sessionC.range(1, 10000, 2).map { i => Thread.sleep(100); i }.count()
+        } finally {
+          sessionC.clearTags() // clear for the case of thread reuse by another Future
+        }
+      }(executionContext)
+
+      // Block until all three jobs have started.
+      assert(sem.tryAcquire(3, 1, TimeUnit.MINUTES))
+
+      // Tags are applied
+      assert(jobProperties.size == 3)
+      for (ss <- Seq(sessionA, sessionB, sessionC)) {
+        val jobProperty = jobProperties.values().asScala.filter(_.get(SparkContext.SPARK_JOB_TAGS)
+          .asInstanceOf[String].contains(ss.sessionUUID))
+        assert(jobProperty.size == 1)
+        val tags = jobProperty.head.get(SparkContext.SPARK_JOB_TAGS).asInstanceOf[String]
+          .split(SparkContext.SPARK_JOB_TAGS_SEP)
+
+        val executionRootIdTag = SQLExecution.executionIdJobTag(
+          ss,
+          jobProperty.head.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)
+        val userTagsPrefix = s"spark-session-${ss.sessionUUID}-"
+
+        ss match {
+          case s if s == sessionA => assert(tags.toSet == Set(
+            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one"))
+          case s if s == sessionB => assert(tags.toSet == Set(
+            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one", s"${userTagsPrefix}two"))
+          case s if s == sessionC => assert(tags.toSet == Set(
+            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}boo"))
+        }
+      }
+
+      // Global session cancels nothing
+      assert(globalSession.interruptAll().isEmpty)
+      assert(globalSession.interruptTag("one").isEmpty)
+      assert(globalSession.interruptTag("two").isEmpty)
+      for (i <- SQLExecution.executionIdToQueryExecution.keys().asScala) {
+        assert(globalSession.interruptOperation(i.toString).isEmpty)
+      }
+      assert(jobEnded.intValue == 0)
+
+      // One job cancelled
+      for (i <- SQLExecution.executionIdToQueryExecution.keys().asScala) {
+        sessionC.interruptOperation(i.toString)
+      }
+      val eC = intercept[SparkException] {
+        ThreadUtils.awaitResult(jobC, 1.minute)
+      }.getCause
+      assert(eC.getMessage contains "cancelled")
+      assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
+      assert(jobEnded.intValue == 1)
+
+      // Another job cancelled
+      assert(sessionA.interruptTag("one").size == 1)
+      val eA = intercept[SparkException] {
+        ThreadUtils.awaitResult(jobA,
1.minute) + }.getCause + assert(eA.getMessage contains "cancelled job tags one") + assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES)) + assert(jobEnded.intValue == 2) + + // The last job cancelled + sessionB.interruptAll() + val eB = intercept[SparkException] { + ThreadUtils.awaitResult(jobB, 1.minute) + }.getCause + assert(eB.getMessage contains "cancellation of all jobs") + assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES)) + assert(jobEnded.intValue == 3) + } finally { + fpool.shutdownNow() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 919958d304f10..948a0e3444cd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -76,14 +76,14 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared exception = intercept[AnalysisException] { sql(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") }, - errorClass = "UNSUPPORTED_FEATURE.ANALYZE_VIEW", + condition = "UNSUPPORTED_FEATURE.ANALYZE_VIEW", parameters = Map.empty ) checkError( exception = intercept[AnalysisException] { sql(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") }, - errorClass = "UNSUPPORTED_FEATURE.ANALYZE_VIEW", + condition = "UNSUPPORTED_FEATURE.ANALYZE_VIEW", parameters = Map.empty ) } @@ -136,7 +136,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared exception = intercept[AnalysisException] { sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data") }, - errorClass = "UNSUPPORTED_FEATURE.ANALYZE_UNSUPPORTED_COLUMN_TYPE", + condition = "UNSUPPORTED_FEATURE.ANALYZE_UNSUPPORTED_COLUMN_TYPE", parameters = Map( "columnType" -> "\"ARRAY\"", "columnName" -> "`data`", @@ -149,7 +149,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared exception = intercept[AnalysisException] { sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS some_random_column") }, - errorClass = "COLUMN_NOT_FOUND", + condition = "COLUMN_NOT_FOUND", parameters = Map( "colName" -> "`some_random_column`", "caseSensitiveConfig" -> "\"spark.sql.caseSensitive\"" @@ -630,7 +630,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared exception = intercept[AnalysisException] { sql("ANALYZE TABLE tempView COMPUTE STATISTICS FOR COLUMNS id") }, - errorClass = "UNSUPPORTED_FEATURE.ANALYZE_UNCACHED_TEMP_VIEW", + condition = "UNSUPPORTED_FEATURE.ANALYZE_UNCACHED_TEMP_VIEW", parameters = Map("viewName" -> "`tempView`") ) @@ -656,7 +656,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared exception = intercept[AnalysisException] { sql(s"ANALYZE TABLE $globalTempDB.gTempView COMPUTE STATISTICS FOR COLUMNS id") }, - errorClass = "UNSUPPORTED_FEATURE.ANALYZE_UNCACHED_TEMP_VIEW", + condition = "UNSUPPORTED_FEATURE.ANALYZE_UNCACHED_TEMP_VIEW", parameters = Map("viewName" -> "`global_temp`.`gTempView`") ) @@ -775,7 +775,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared exception = intercept[AnalysisException] { sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS value, name, $dupCol") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`value`")) } } @@ -849,7 +849,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared 
sql(s"ANALYZE TABLES IN db_not_exists COMPUTE STATISTICS") } checkError(e, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> "`spark_catalog`.`db_not_exists`")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index ef8b66566f246..7fa29dd38fd96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -366,7 +366,7 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") - if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { + if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION.key) == "hive") { sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)") sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1") val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats @@ -381,7 +381,7 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils // Test data source table checkStatsConversion(tableName = "ds_tbl", isDatasourceTable = true) // Test hive serde table - if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { + if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION.key) == "hive") { checkStatsConversion(tableName = "hive_tbl", isDatasourceTable = false) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 523b3518db48c..ec240d71b851f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -714,6 +714,34 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { df.select(sentences($"str", $"language", $"country")), Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now")))) + checkAnswer( + df.selectExpr("sentences(str, language)"), + Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now")))) + + checkAnswer( + df.select(sentences($"str", $"language")), + Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now")))) + + checkAnswer( + df.selectExpr("sentences(str)"), + Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now")))) + + checkAnswer( + df.select(sentences($"str")), + Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now")))) + + checkAnswer( + df.selectExpr("sentences(str, null, null)"), + Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now")))) + + checkAnswer( + df.selectExpr("sentences(str, '', null)"), + Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now")))) + + checkAnswer( + df.selectExpr("sentences(str, null)"), + Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now")))) + // Type coercion checkAnswer( df.selectExpr("sentences(null)", "sentences(10)", "sentences(3.14)"), @@ -727,7 +755,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("sentences()") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", 
+ condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> toSQLId("sentences"), "expectedNum" -> "[1, 2, 3]", @@ -828,7 +856,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("select regexp_replace(collect_list(1), '1', '2')") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"regexp_replace(collect_list(1), 1, 2, 1)\"", @@ -848,7 +876,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { sql("select regexp_replace('', '[a\\\\d]{0, 2}', 'x')").collect() }, - errorClass = "INVALID_PARAMETER_VALUE.PATTERN", + condition = "INVALID_PARAMETER_VALUE.PATTERN", parameters = Map( "parameter" -> toSQLId("regexp"), "functionName" -> toSQLId("regexp_replace"), @@ -859,7 +887,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { sql("select regexp_extract('', '[a\\\\d]{0, 2}', 1)").collect() }, - errorClass = "INVALID_PARAMETER_VALUE.PATTERN", + condition = "INVALID_PARAMETER_VALUE.PATTERN", parameters = Map( "parameter" -> toSQLId("regexp"), "functionName" -> toSQLId("regexp_extract"), @@ -870,7 +898,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { sql("select rlike('', '[a\\\\d]{0, 2}')").collect() }, - errorClass = "INVALID_PARAMETER_VALUE.PATTERN", + condition = "INVALID_PARAMETER_VALUE.PATTERN", parameters = Map( "parameter" -> toSQLId("regexp"), "functionName" -> toSQLId("rlike"), @@ -920,7 +948,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df2.select(func(col("input"), col("format"))).collect() }, - errorClass = "NON_FOLDABLE_ARGUMENT", + condition = "NON_FOLDABLE_ARGUMENT", parameters = Map( "funcName" -> s"`$funcName`", "paramName" -> "`format`", @@ -932,7 +960,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df2.select(func(col("input"), lit("invalid_format"))).collect() }, - errorClass = "INVALID_PARAMETER_VALUE.BINARY_FORMAT", + condition = "INVALID_PARAMETER_VALUE.BINARY_FORMAT", parameters = Map( "parameter" -> "`format`", "functionName" -> s"`$funcName`", @@ -944,7 +972,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"select $funcName('a', 'b', 'c')") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> s"`$funcName`", "expectedNum" -> "2", @@ -955,7 +983,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"select $funcName(x'537061726b2053514c', CAST(NULL AS STRING))") }, - errorClass = "INVALID_PARAMETER_VALUE.NULL", + condition = "INVALID_PARAMETER_VALUE.NULL", parameters = Map( "functionName" -> s"`$funcName`", "parameter" -> "`format`"), @@ -1058,7 +1086,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df1.select(like(col("a"), col("b"), lit(618))).collect() }, - errorClass = "INVALID_ESCAPE_CHAR", + condition = "INVALID_ESCAPE_CHAR", parameters = Map("sqlExpr" -> "\"618\""), context = ExpectedContext("like", 
getCurrentClassCallSitePattern) ) @@ -1067,7 +1095,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df1.select(ilike(col("a"), col("b"), lit(618))).collect() }, - errorClass = "INVALID_ESCAPE_CHAR", + condition = "INVALID_ESCAPE_CHAR", parameters = Map("sqlExpr" -> "\"618\""), context = ExpectedContext("ilike", getCurrentClassCallSitePattern) ) @@ -1078,7 +1106,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df1.select(like(col("a"), col("b"), lit("中国"))).collect() }, - errorClass = "INVALID_ESCAPE_CHAR", + condition = "INVALID_ESCAPE_CHAR", parameters = Map("sqlExpr" -> "\"中国\""), context = ExpectedContext("like", getCurrentClassCallSitePattern) ) @@ -1087,7 +1115,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df1.select(ilike(col("a"), col("b"), lit("中国"))).collect() }, - errorClass = "INVALID_ESCAPE_CHAR", + condition = "INVALID_ESCAPE_CHAR", parameters = Map("sqlExpr" -> "\"中国\""), context = ExpectedContext("ilike", getCurrentClassCallSitePattern) ) @@ -1282,7 +1310,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { intercept[SparkRuntimeException](df.queryExecution.optimizedPlan) checkError( exception = intercept[SparkRuntimeException](df.queryExecution.explainString(FormattedMode)), - errorClass = "INVALID_PARAMETER_VALUE.PATTERN", + condition = "INVALID_PARAMETER_VALUE.PATTERN", parameters = Map( "parameter" -> toSQLId("regexp"), "functionName" -> toSQLId("regexp_replace"), @@ -1310,7 +1338,7 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"select concat_ws(',', collect_list(dat)) FROM $testTable") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> """"concat_ws(,, collect_list(dat))"""", "paramIndex" -> "second", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 68f14f13bbd66..23c4d51983bb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -533,7 +533,7 @@ class SubquerySuite extends QueryTest } checkError( exception, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + "NON_CORRELATED_COLUMNS_IN_GROUP_BY", parameters = Map("value" -> "c2"), sqlState = None, @@ -548,7 +548,7 @@ class SubquerySuite extends QueryTest } checkError( exception1, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + "MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY", parameters = Map.empty, context = ExpectedContext( @@ -558,7 +558,7 @@ class SubquerySuite extends QueryTest } checkErrorMatchPVals( exception2, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + "MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY", parameters = Map.empty[String, String], sqlState = None, @@ -850,7 +850,7 @@ class SubquerySuite extends QueryTest } checkErrorMatchPVals( exception1, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." 
+ "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED", parameters = Map("treeNode" -> "(?s).*"), sqlState = None, @@ -872,7 +872,7 @@ class SubquerySuite extends QueryTest } checkErrorMatchPVals( exception2, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED", parameters = Map("treeNode" -> "(?s).*"), sqlState = None, @@ -893,7 +893,7 @@ class SubquerySuite extends QueryTest } checkErrorMatchPVals( exception3, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED", parameters = Map("treeNode" -> "(?s).*"), sqlState = None, @@ -1057,7 +1057,7 @@ class SubquerySuite extends QueryTest } checkError( exception1, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_REFERENCE", + condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_REFERENCE", parameters = Map("sqlExprs" -> "\"explode(arr_c2)\""), context = ExpectedContext( fragment = "LATERAL VIEW explode(t2.arr_c2) q AS c2", @@ -1098,7 +1098,7 @@ class SubquerySuite extends QueryTest checkError( exception = intercept[AnalysisException](sql(query)), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map( "objectName" -> "`a`", @@ -2552,7 +2552,7 @@ class SubquerySuite extends QueryTest |""".stripMargin ).collect() }, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + "UNSUPPORTED_CORRELATED_REFERENCE_DATA_TYPE", parameters = Map("expr" -> "v1.x", "dataType" -> "map"), context = ExpectedContext( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TableOptionsConstantFoldingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TableOptionsConstantFoldingSuite.scala index 2e56327a63136..aa82ac57089f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TableOptionsConstantFoldingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TableOptionsConstantFoldingSuite.scala @@ -70,42 +70,42 @@ class TableOptionsConstantFoldingSuite extends QueryTest with SharedSparkSession checkError( exception = intercept[AnalysisException]( sql(s"$prefix ('k' = 1 + 2 + unresolvedAttribute)")), - errorClass = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", parameters = Map( "objectName" -> "`unresolvedAttribute`"), queryContext = Array(ExpectedContext("", "", 60, 78, "unresolvedAttribute"))) checkError( exception = intercept[AnalysisException]( sql(s"$prefix ('k' = true or false or unresolvedAttribute)")), - errorClass = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", parameters = Map( "objectName" -> "`unresolvedAttribute`"), queryContext = Array(ExpectedContext("", "", 69, 87, "unresolvedAttribute"))) checkError( exception = intercept[AnalysisException]( sql(s"$prefix ('k' = cast(array('9', '9') as array))")), - errorClass = "INVALID_SQL_SYNTAX.OPTION_IS_INVALID", + condition = "INVALID_SQL_SYNTAX.OPTION_IS_INVALID", parameters = Map( "key" -> "k", "supported" -> "constant expressions")) checkError( exception = intercept[AnalysisException]( sql(s"$prefix ('k' = cast(map('9', '9') as map))")), - errorClass = "INVALID_SQL_SYNTAX.OPTION_IS_INVALID", + condition = "INVALID_SQL_SYNTAX.OPTION_IS_INVALID", parameters = Map( "key" -> "k", "supported" -> "constant 
expressions")) checkError( exception = intercept[AnalysisException]( sql(s"$prefix ('k' = raise_error('failure'))")), - errorClass = "INVALID_SQL_SYNTAX.OPTION_IS_INVALID", + condition = "INVALID_SQL_SYNTAX.OPTION_IS_INVALID", parameters = Map( "key" -> "k", "supported" -> "constant expressions")) checkError( exception = intercept[AnalysisException]( sql(s"$prefix ('k' = raise_error('failure'))")), - errorClass = "INVALID_SQL_SYNTAX.OPTION_IS_INVALID", + condition = "INVALID_SQL_SYNTAX.OPTION_IS_INVALID", parameters = Map( "key" -> "k", "supported" -> "constant expressions")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 36552d5c5487c..2e072e5afc926 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -125,7 +125,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> toSQLId("substr"), "expectedNum" -> "[2, 3]", @@ -146,7 +146,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { spark.udf.register("foo", (_: String).length) df.selectExpr("foo(2, 3, 4)") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> toSQLId("foo"), "expectedNum" -> "1", @@ -166,7 +166,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { spark.emptyDataFrame.selectExpr(sqlText) }, - errorClass = "UNRESOLVED_ROUTINE", + condition = "UNRESOLVED_ROUTINE", parameters = Map( "routineName" -> "`a_function_that_does_not_exist`", "searchPath" -> "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]"), @@ -772,7 +772,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.select(myUdf(Column("col")))), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`b`", "proposal" -> "`a`"), @@ -1206,7 +1206,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { ) checkError( intercept[AnalysisException](spark.range(1).select(f())), - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_ENCODER", + condition = "UNSUPPORTED_DATA_TYPE_FOR_ENCODER", sqlState = "0A000", parameters = Map("dataType" -> s"\"${dt.sql}\"") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UrlFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UrlFunctionsSuite.scala index c89ddd0e6a1f1..428065fb6986f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UrlFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UrlFunctionsSuite.scala @@ -76,7 +76,7 @@ class UrlFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkIllegalArgumentException] { sql(s"SELECT parse_url('$url', 'HOST')").collect() }, - errorClass = "INVALID_URL", + condition = "INVALID_URL", parameters = Map( "url" -> url, "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala index 4a20ec4af7e65..19d4ac23709b6 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala @@ -16,7 +16,11 @@ */ package org.apache.spark.sql +import org.apache.spark.SparkThrowable +import org.apache.spark.sql.QueryTest.sameRows +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.sql.catalyst.expressions.variant.{ToVariantObject, VariantExpressionEvalUtils} import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.functions._ @@ -25,6 +29,7 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarArray import org.apache.spark.types.variant.VariantBuilder +import org.apache.spark.types.variant.VariantUtil._ import org.apache.spark.unsafe.types.VariantVal class VariantEndToEndSuite extends QueryTest with SharedSparkSession { @@ -34,8 +39,10 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { def check(input: String, output: String = null): Unit = { val df = Seq(input).toDF("v") val variantDF = df.select(to_json(parse_json(col("v")))) + val variantDF2 = df.select(to_json(from_json(col("v"), VariantType))) val expected = if (output != null) output else input checkAnswer(variantDF, Seq(Row(expected))) + checkAnswer(variantDF2, Seq(Row(expected))) } check("null") @@ -158,6 +165,34 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { checkAnswer(variantDF, Seq(Row(expected))) } + test("to_variant_object - Codegen Support") { + Seq("CODEGEN_ONLY", "NO_CODEGEN").foreach { codegenMode => + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + val schema = StructType(Array( + StructField("v", StructType(Array(StructField("a", IntegerType)))) + )) + val data = Seq(Row(Row(1)), Row(Row(2)), Row(Row(3)), Row(null)) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + val variantDF = df.select(to_variant_object(col("v"))) + val plan = variantDF.queryExecution.executedPlan + assert(plan.isInstanceOf[WholeStageCodegenExec] == (codegenMode == "CODEGEN_ONLY")) + val v1 = VariantExpressionEvalUtils.castToVariant(InternalRow(1), + StructType(Array(StructField("a", IntegerType)))) + val v2 = VariantExpressionEvalUtils.castToVariant(InternalRow(2), + StructType(Array(StructField("a", IntegerType)))) + val v3 = VariantExpressionEvalUtils.castToVariant(InternalRow(3), + StructType(Array(StructField("a", IntegerType)))) + val v4 = VariantExpressionEvalUtils.castToVariant(null, + StructType(Array(StructField("a", IntegerType)))) + val expected = Seq(Row(new VariantVal(v1.getValue, v1.getMetadata)), + Row(new VariantVal(v2.getValue, v2.getMetadata)), + Row(new VariantVal(v3.getValue, v3.getMetadata)), + Row(new VariantVal(v4.getValue, v4.getMetadata))) + sameRows(variantDF.collect().toSeq, expected) + } + } + } + test("schema_of_variant") { def check(json: String, expected: String): Unit = { val df = Seq(json).toDF("j").selectExpr("schema_of_variant(parse_json(j))") @@ -181,8 +216,8 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { check("1E0", "DOUBLE") check("true", "BOOLEAN") check("\"2000-01-01\"", "STRING") - check("""{"a":0}""", "STRUCT") - check("""{"b": {"c": "c"}, "a":["a"]}""", "STRUCT, b: STRUCT>") + check("""{"a":0}""", "OBJECT") + check("""{"b": {"c": "c"}, "a":["a"]}""", 
"OBJECT, b: OBJECT>") check("[]", "ARRAY") check("[false]", "ARRAY") check("[null, 1, 1.0]", "ARRAY") @@ -192,11 +227,11 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { check("[1.1, 11111111111111111111111111111111111111]", "ARRAY") check("[1, \"1\"]", "ARRAY") check("[{}, true]", "ARRAY") - check("""[{"c": ""}, {"a": null}, {"b": 1}]""", "ARRAY>") - check("""[{"a": ""}, {"a": null}, {"b": 1}]""", "ARRAY>") + check("""[{"c": ""}, {"a": null}, {"b": 1}]""", "ARRAY>") + check("""[{"a": ""}, {"a": null}, {"b": 1}]""", "ARRAY>") check( """[{"a": 1, "b": null}, {"b": true, "a": 1E0}]""", - "ARRAY>" + "ARRAY>" ) } @@ -233,7 +268,7 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { // Literal input. checkAnswer( sql("""SELECT schema_of_variant_agg(parse_json('{"a": [1, 2, 3]}'))"""), - Seq(Row("STRUCT>"))) + Seq(Row("OBJECT>"))) // Non-grouping aggregation. def checkNonGrouping(input: Seq[String], expected: String): Unit = { @@ -241,20 +276,20 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { Seq(Row(expected))) } - checkNonGrouping(Seq("""{"a": [1, 2, 3]}"""), "STRUCT>") - checkNonGrouping((0 to 100).map(i => s"""{"a": [$i]}"""), "STRUCT>") - checkNonGrouping(Seq("""[{"a": 1}, {"b": 2}]"""), "ARRAY>") - checkNonGrouping(Seq("""{"a": [1, 2, 3]}""", """{"a": "banana"}"""), "STRUCT") + checkNonGrouping(Seq("""{"a": [1, 2, 3]}"""), "OBJECT>") + checkNonGrouping((0 to 100).map(i => s"""{"a": [$i]}"""), "OBJECT>") + checkNonGrouping(Seq("""[{"a": 1}, {"b": 2}]"""), "ARRAY>") + checkNonGrouping(Seq("""{"a": [1, 2, 3]}""", """{"a": "banana"}"""), "OBJECT") checkNonGrouping(Seq("""{"a": "banana"}""", """{"b": "apple"}"""), - "STRUCT") - checkNonGrouping(Seq("""{"a": "data"}""", null), "STRUCT") + "OBJECT") + checkNonGrouping(Seq("""{"a": "data"}""", null), "OBJECT") checkNonGrouping(Seq(null, null), "VOID") - checkNonGrouping(Seq("""{"a": null}""", """{"a": null}"""), "STRUCT") + checkNonGrouping(Seq("""{"a": null}""", """{"a": null}"""), "OBJECT") checkNonGrouping(Seq( """{"hi":[]}""", """{"hi":[{},{}]}""", """{"hi":[{"it's":[{"me":[{"a": 1}]}]}]}"""), - "STRUCT>>>>>>") + "OBJECT>>>>>>") // Grouping aggregation. 
withView("v") { @@ -263,11 +298,11 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { (id, json) }.toDF("id", "json").createTempView("v") checkAnswer(sql("select schema_of_variant_agg(parse_json(json)) from v group by id % 2"), - Seq(Row("STRUCT>"), Row("STRUCT>"))) + Seq(Row("OBJECT>"), Row("OBJECT>"))) checkAnswer(sql("select schema_of_variant_agg(parse_json(json)) from v group by id % 3"), - Seq.fill(3)(Row("STRUCT>"))) + Seq.fill(3)(Row("OBJECT>"))) checkAnswer(sql("select schema_of_variant_agg(parse_json(json)) from v group by id % 4"), - Seq.fill(3)(Row("STRUCT>")) ++ Seq(Row("STRUCT>"))) + Seq.fill(3)(Row("OBJECT>")) ++ Seq(Row("OBJECT>"))) } } @@ -279,22 +314,61 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { dataVector.appendLong(456) val array = new ColumnarArray(dataVector, 0, 4) val variant = Cast(Literal(array, ArrayType(LongType)), VariantType).eval() + val variant2 = ToVariantObject(Literal(array, ArrayType(LongType))).eval() assert(variant.toString == "[null,123,null,456]") + assert(variant2.toString == "[null,123,null,456]") dataVector.close() } - test("cast to variant with scan input") { - withTempPath { dir => - val path = dir.getAbsolutePath - val input = Seq(Row(Array(1, null), Map("k1" -> null, "k2" -> false), Row(null, "str"))) - val schema = StructType.fromDDL( - "a array, m map, s struct") - spark.createDataFrame(spark.sparkContext.parallelize(input), schema).write.parquet(path) - val df = spark.read.parquet(path).selectExpr( - s"cast(cast(a as variant) as ${schema(0).dataType.sql})", - s"cast(cast(m as variant) as ${schema(1).dataType.sql})", - s"cast(cast(s as variant) as ${schema(2).dataType.sql})") - checkAnswer(df, input) + test("cast to variant/to_variant_object with scan input") { + Seq("NO_CODEGEN", "CODEGEN_ONLY").foreach { codegenMode => + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + withTempPath { dir => + val path = dir.getAbsolutePath + val input = Seq( + Row(Array(1, null), Map("k1" -> null, "k2" -> false), Row(null, "str")), + Row(null, null, null) + ) + val schema = StructType.fromDDL( + "a array, m map, s struct") + spark.createDataFrame(spark.sparkContext.parallelize(input), schema).write.parquet(path) + val df = spark.read.parquet(path).selectExpr( + s"cast(cast(a as variant) as ${schema(0).dataType.sql})", + s"cast(to_variant_object(m) as ${schema(1).dataType.sql})", + s"cast(to_variant_object(s) as ${schema(2).dataType.sql})") + checkAnswer(df, input) + val plan = df.queryExecution.executedPlan + assert(plan.isInstanceOf[WholeStageCodegenExec] == (codegenMode == "CODEGEN_ONLY")) + } + } + } + } + + test("from_json(_, 'variant') with duplicate keys") { + val json: String = """{"a": 1, "b": 2, "c": "3", "a": 4}""" + withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "true") { + val df = Seq(json).toDF("j") + .selectExpr("from_json(j,'variant')") + val actual = df.collect().head(0).asInstanceOf[VariantVal] + val expectedValue: Array[Byte] = Array(objectHeader(false, 1, 1), + /* size */ 3, + /* id list */ 0, 1, 2, + /* offset list */ 4, 0, 2, 6, + /* field data */ primitiveHeader(INT1), 2, shortStrHeader(1), '3', + primitiveHeader(INT1), 4) + val expectedMetadata: Array[Byte] = Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c') + assert(actual === new VariantVal(expectedValue, expectedMetadata)) + } + withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "false") { + val df = Seq(json).toDF("j") + .selectExpr("from_json(j,'variant')") + checkError( + exception = 
intercept[SparkThrowable] { + df.collect() + }, + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + parameters = Map("badRecord" -> json, "failFastMode" -> "FAILFAST") + ) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala index 0c8b0b501951f..5d59a3e0f8256 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala @@ -87,8 +87,8 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval def rows(results: Any*): Seq[Row] = results.map(Row(_)) checkAnswer(df.select(is_variant_null(v)), rows(false, false)) - checkAnswer(df.select(schema_of_variant(v)), rows("STRUCT", "STRUCT")) - checkAnswer(df.select(schema_of_variant_agg(v)), rows("STRUCT")) + checkAnswer(df.select(schema_of_variant(v)), rows("OBJECT", "OBJECT")) + checkAnswer(df.select(schema_of_variant_agg(v)), rows("OBJECT")) checkAnswer(df.select(variant_get(v, "$.a", "int")), rows(1, null)) checkAnswer(df.select(variant_get(v, "$.b", "int")), rows(null, 2)) @@ -97,7 +97,7 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval exception = intercept[SparkRuntimeException] { df.select(variant_get(v, "$.a", "binary")).collect() }, - errorClass = "INVALID_VARIANT_CAST", + condition = "INVALID_VARIANT_CAST", parameters = Map("value" -> "1", "dataType" -> "\"BINARY\"") ) @@ -223,7 +223,7 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval exception = intercept[AnalysisException] { query.write.partitionBy("v").parquet(tempDir) }, - errorClass = "INVALID_PARTITION_COLUMN_DATA_TYPE", + condition = "INVALID_PARTITION_COLUMN_DATA_TYPE", parameters = Map("type" -> "\"VARIANT\"") ) } @@ -239,7 +239,7 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval exception = intercept[AnalysisException] { query.write.partitionBy("v").saveAsTable("t") }, - errorClass = "INVALID_PARTITION_COLUMN_DATA_TYPE", + condition = "INVALID_PARTITION_COLUMN_DATA_TYPE", parameters = Map("type" -> "\"VARIANT\"") ) } @@ -255,7 +255,7 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval exception = intercept[AnalysisException] { spark.sql(s"CREATE TABLE t USING PARQUET PARTITIONED BY (v) AS $queryString") }, - errorClass = "INVALID_PARTITION_COLUMN_DATA_TYPE", + condition = "INVALID_PARTITION_COLUMN_DATA_TYPE", parameters = Map("type" -> "\"VARIANT\"") ) } @@ -290,7 +290,7 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval (s"named_struct('value', $v, 'metadata', cast(null as binary))", "INVALID_VARIANT_FROM_PARQUET.NULLABLE_OR_NOT_BINARY_FIELD", Map("field" -> "metadata")) ) - cases.foreach { case (structDef, errorClass, parameters) => + cases.foreach { case (structDef, condition, parameters) => Seq(false, true).foreach { vectorizedReader => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedReader.toString) { withTempDir { dir => @@ -302,7 +302,7 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval val e = intercept[org.apache.spark.SparkException](result.collect()) checkError( exception = e.getCause.asInstanceOf[AnalysisException], - errorClass = errorClass, + condition = condition, parameters = parameters ) } @@ -346,7 +346,7 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval exception = intercept[AnalysisException] 
{ spark.read.format("json").option("singleVariantColumn", "var").schema("var variant") }, - errorClass = "INVALID_SINGLE_VARIANT_COLUMN", + condition = "INVALID_SINGLE_VARIANT_COLUMN", parameters = Map.empty ) checkError( @@ -354,7 +354,7 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval spark.read.format("json").option("singleVariantColumn", "another_name") .schema("var variant").json(file.getAbsolutePath).collect() }, - errorClass = "INVALID_SINGLE_VARIANT_COLUMN", + condition = "INVALID_SINGLE_VARIANT_COLUMN", parameters = Map.empty ) } @@ -422,7 +422,7 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval exception = intercept[AnalysisException] { spark.sql("select parse_json('') group by 1") }, - errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE", + condition = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE", parameters = Map("sqlExpr" -> "\"parse_json()\"", "dataType" -> "\"VARIANT\""), context = ExpectedContext(fragment = "parse_json('')", start = 7, stop = 20) ) @@ -431,7 +431,7 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval exception = intercept[AnalysisException] { spark.sql("select parse_json('') order by 1") }, - errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", parameters = Map( "functionName" -> "`sortorder`", "dataType" -> "\"VARIANT\"", @@ -443,7 +443,7 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval exception = intercept[AnalysisException] { spark.sql("select parse_json('') sort by 1") }, - errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", parameters = Map( "functionName" -> "`sortorder`", "dataType" -> "\"VARIANT\"", @@ -456,7 +456,7 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval spark.sql("with t as (select 1 as a, parse_json('') as v) " + "select rank() over (partition by a order by v) from t") }, - errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", parameters = Map( "functionName" -> "`sortorder`", "dataType" -> "\"VARIANT\"", @@ -469,7 +469,7 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval spark.sql("with t as (select parse_json('') as v) " + "select t1.v from t as t1 join t as t2 on t1.v = t2.v") }, - errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", parameters = Map( "functionName" -> "`=`", "dataType" -> "\"VARIANT\"", @@ -806,4 +806,11 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval checkSize(structResult.getAs[VariantVal](0), 5, 10, 5, 10) checkSize(structResult.getAs[VariantVal](1), 2, 4, 2, 4) } + + test("schema_of_variant(object)") { + for (expr <- Seq("schema_of_variant", "schema_of_variant_agg")) { + val q = s"""select $expr(parse_json('{"STRUCT": {"!special!": true}}'))""" + checkAnswer(sql(q), Row("""OBJECT>""")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala index 4169d53e4fc8e..f9d003572a229 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala @@ -126,7 +126,7 @@ class XmlFunctionsSuite extends QueryTest with SharedSparkSession { exception = 
intercept[AnalysisException] { Seq("1").toDS().select(from_xml($"value", lit("ARRAY"), Map[String, String]().asJava)) }, - errorClass = "INVALID_SCHEMA.NON_STRUCT_TYPE", + condition = "INVALID_SCHEMA.NON_STRUCT_TYPE", parameters = Map( "inputSchema" -> "\"ARRAY\"", "dataType" -> "\"ARRAY\"" @@ -138,7 +138,7 @@ class XmlFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq("1").toDF("xml").selectExpr(s"from_xml(xml, 'ARRAY')") }, - errorClass = "INVALID_SCHEMA.NON_STRUCT_TYPE", + condition = "INVALID_SCHEMA.NON_STRUCT_TYPE", parameters = Map( "inputSchema" -> "\"ARRAY\"", "dataType" -> "\"ARRAY\"" @@ -285,7 +285,7 @@ class XmlFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df2.selectExpr("to_xml(a, named_struct('a', 1))") }, - errorClass = "INVALID_OPTIONS.NON_MAP_FUNCTION", + condition = "INVALID_OPTIONS.NON_MAP_FUNCTION", parameters = Map.empty, context = ExpectedContext( fragment = "to_xml(a, named_struct('a', 1))", @@ -298,7 +298,7 @@ class XmlFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df2.selectExpr("to_xml(a, map('a', 1))") }, - errorClass = "INVALID_OPTIONS.NON_STRING_TYPE", + condition = "INVALID_OPTIONS.NON_STRING_TYPE", parameters = Map("mapType" -> "\"MAP\""), context = ExpectedContext( fragment = "to_xml(a, map('a', 1))", @@ -350,7 +350,7 @@ class XmlFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df3.selectExpr("from_xml(value, 1)") }, - errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", + condition = "INVALID_SCHEMA.NON_STRING_LITERAL", parameters = Map("inputSchema" -> "\"1\""), context = ExpectedContext( fragment = "from_xml(value, 1)", @@ -362,7 +362,7 @@ class XmlFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df3.selectExpr("""from_xml(value, 'time InvalidType')""") }, - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map( "error" -> "'InvalidType'", @@ -378,7 +378,7 @@ class XmlFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df3.selectExpr("from_xml(value, 'time Timestamp', named_struct('a', 1))") }, - errorClass = "INVALID_OPTIONS.NON_MAP_FUNCTION", + condition = "INVALID_OPTIONS.NON_MAP_FUNCTION", parameters = Map.empty, context = ExpectedContext( fragment = "from_xml(value, 'time Timestamp', named_struct('a', 1))", @@ -390,7 +390,7 @@ class XmlFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df3.selectExpr("from_xml(value, 'time Timestamp', map('a', 1))") }, - errorClass = "INVALID_OPTIONS.NON_STRING_TYPE", + condition = "INVALID_OPTIONS.NON_STRING_TYPE", parameters = Map("mapType" -> "\"MAP\""), context = ExpectedContext( fragment = "from_xml(value, 'time Timestamp', map('a', 1))", @@ -518,7 +518,7 @@ class XmlFunctionsSuite extends QueryTest with SharedSparkSession { Seq(("""1""", "i int")).toDF("xml", "schema") .select(from_xml($"xml", $"schema", options)).collect() }, - errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", + condition = "INVALID_SCHEMA.NON_STRING_LITERAL", parameters = Map("inputSchema" -> "\"schema\""), context = ExpectedContext(fragment = "from_xml", getCurrentClassCallSitePattern) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala index 66cccb497bc7f..8fcc300d5c254 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala @@ -375,4 +375,17 @@ class ArtifactManagerSuite extends SharedSparkSession { val msg = instance.getClass.getMethod("msg").invoke(instance) assert(msg == "Hello Talon! Nice to meet you!") } + + test("Support Windows style paths") { + withTempPath { path => + val stagingPath = path.toPath + Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8)) + val remotePath = Paths.get("windows\\abc.txt") + artifactManager.addArtifact(remotePath, stagingPath, None) + val file = ArtifactManager.artifactRootDirectory + .resolve(s"$sessionUUID/windows/abc.txt") + .toFile + assert(file.exists()) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/catalyst/expressions/ValidateExternalTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/expressions/ValidateExternalTypeSuite.scala index 57b9e592f31b0..6e54c2a1942ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/catalyst/expressions/ValidateExternalTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/expressions/ValidateExternalTypeSuite.scala @@ -32,7 +32,7 @@ class ValidateExternalTypeSuite extends QueryTest with SharedSparkSession { ) )), new StructType().add("f3", StringType)).show() }.getCause.asInstanceOf[SparkRuntimeException], - errorClass = "INVALID_EXTERNAL_TYPE", + condition = "INVALID_EXTERNAL_TYPE", parameters = Map( ("externalType", "[B"), ("type", "\"STRING\""), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index 5b0d61fb6d771..21aa57cc1eace 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -166,7 +166,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { sql(s"ALTER TABLE $t ADD COLUMN c string AFTER non_exist")) checkError( exception = e1, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`c`", "fields" -> "a, point, b") ) @@ -191,7 +191,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { sql(s"ALTER TABLE $t ADD COLUMN point.x2 int AFTER non_exist")) checkError( exception = e2, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`x2`", "fields" -> "y, x, z") ) } @@ -231,7 +231,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { sql(s"ALTER TABLE $t ADD COLUMNS (yy int AFTER xx, xx int)")) checkError( exception = e, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`yy`", "fields" -> "a, x, y, z, b, point") ) } @@ -372,7 +372,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql("alter table t add column s bigint default badvalue") }, - errorClass = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", parameters = Map( "statement" -> "ALTER TABLE", "colName" -> "`s`", @@ -383,7 +383,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = 
intercept[AnalysisException] { sql("alter table t alter column s set default badvalue") }, - errorClass = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", parameters = Map( "statement" -> "ALTER TABLE ALTER COLUMN", "colName" -> "`s`", @@ -437,7 +437,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`point`", @@ -475,7 +475,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t ADD COLUMNS $field double") }, - errorClass = "FIELD_ALREADY_EXISTS", + condition = "FIELD_ALREADY_EXISTS", parameters = expectedParameters, context = ExpectedContext( fragment = s"ALTER TABLE $t ADD COLUMNS $field double", @@ -494,7 +494,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t ADD COLUMNS (data string, data1 string, data string)") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`data`")) } } @@ -507,7 +507,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t ADD COLUMNS (point.z double, point.z double, point.xx double)") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> toSQLId("point.z"))) } } @@ -538,7 +538,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "CANNOT_UPDATE_FIELD.INTERVAL_TYPE", + condition = "CANNOT_UPDATE_FIELD.INTERVAL_TYPE", parameters = Map( "table" -> s"${toSQLId(prependCatalogName(t))}", "fieldName" -> "`id`"), @@ -600,7 +600,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "CANNOT_UPDATE_FIELD.STRUCT_TYPE", + condition = "CANNOT_UPDATE_FIELD.STRUCT_TYPE", parameters = Map( "table" -> s"${toSQLId(prependCatalogName(t))}", "fieldName" -> "`point`"), @@ -631,7 +631,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "CANNOT_UPDATE_FIELD.ARRAY_TYPE", + condition = "CANNOT_UPDATE_FIELD.ARRAY_TYPE", parameters = Map( "table" -> s"${toSQLId(prependCatalogName(t))}", "fieldName" -> "`points`"), @@ -675,7 +675,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "CANNOT_UPDATE_FIELD.MAP_TYPE", + condition = "CANNOT_UPDATE_FIELD.MAP_TYPE", parameters = Map( "table" -> s"${toSQLId(prependCatalogName(t))}", "fieldName" -> "`m`"), @@ -772,7 +772,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`data`", @@ -791,7 +791,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = 
"UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`point`.`x`", @@ -809,7 +809,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", sqlState = None, parameters = Map( "originType" -> "\"INT\"", @@ -866,7 +866,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText1) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`non_exist`", @@ -896,7 +896,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText2) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`point`.`non_exist`", @@ -989,7 +989,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`data`", @@ -1008,7 +1008,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`point`.`x`", @@ -1110,7 +1110,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`data`", @@ -1129,7 +1129,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`point`.`x`", @@ -1177,7 +1177,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t RENAME COLUMN $field TO $newName") }, - errorClass = "FIELD_ALREADY_EXISTS", + condition = "FIELD_ALREADY_EXISTS", parameters = Map( "op" -> "rename", "fieldNames" -> s"${toSQLId(expectedName)}", @@ -1282,7 +1282,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`data`", @@ -1306,7 +1306,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`point`.`x`", @@ -1392,7 +1392,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { exception = intercept[AnalysisException] { sql(s"ALTER 
TABLE $t REPLACE COLUMNS (data string, data1 string, data string)") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`data`")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index 7bbb6485c273f..fe078c5ae4413 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -55,8 +55,7 @@ class DataSourceV2DataFrameSessionCatalogSuite "and a same-name temp view exist") { withTable("same_name") { withTempView("same_name") { - val format = spark.sessionState.conf.defaultDataSourceName - sql(s"CREATE TABLE same_name(id LONG) USING $format") + sql(s"CREATE TABLE same_name(id LONG) USING $v2Format") spark.range(10).createTempView("same_name") spark.range(20).write.format(v2Format).mode(SaveMode.Append).saveAsTable("same_name") checkAnswer(spark.table("same_name"), spark.range(10).toDF()) @@ -88,6 +87,15 @@ class DataSourceV2DataFrameSessionCatalogSuite assert(tableInfo.properties().get("provider") === v2Format) } } + + test("SPARK-49246: saveAsTable with v1 format") { + withTable("t") { + sql("CREATE TABLE t(c INT) USING csv") + val df = spark.range(10).toDF() + df.write.mode(SaveMode.Overwrite).format("csv").saveAsTable("t") + verifyTable("t", df) + } + } } class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala index 7d48459a8a517..c1e8b70ffddbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala @@ -184,17 +184,17 @@ class DataSourceV2DataFrameSuite val v2Writer = df.writeTo("testcat.table_name") checkError( exception = intercept[AnalysisException](v2Writer.append()), - errorClass = "_LEGACY_ERROR_TEMP_1183", + condition = "_LEGACY_ERROR_TEMP_1183", parameters = Map.empty ) checkError( exception = intercept[AnalysisException](v2Writer.overwrite(df("i"))), - errorClass = "_LEGACY_ERROR_TEMP_1183", + condition = "_LEGACY_ERROR_TEMP_1183", parameters = Map.empty ) checkError( exception = intercept[AnalysisException](v2Writer.overwritePartitions()), - errorClass = "_LEGACY_ERROR_TEMP_1183", + condition = "_LEGACY_ERROR_TEMP_1183", parameters = Map.empty ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index 95bdb2543e376..d6599debd3b11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -145,7 +145,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { exception = intercept[AnalysisException]( sql("SELECT testcat.non_exist('abc')").collect() ), - errorClass = "UNRESOLVED_ROUTINE", + condition = "UNRESOLVED_ROUTINE", parameters = Map( "routineName" -> "`testcat`.`non_exist`", "searchPath" -> "[`system`.`builtin`, `system`.`session`, 
`testcat`.`default`]"), @@ -161,7 +161,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { exception = intercept[AnalysisException]( sql("SELECT testcat.strlen('abc')").collect() ), - errorClass = "_LEGACY_ERROR_TEMP_1184", + condition = "_LEGACY_ERROR_TEMP_1184", parameters = Map("plugin" -> "testcat", "ability" -> "functions") ) } @@ -174,7 +174,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { exception = intercept[AnalysisException] { sql("DESCRIBE FUNCTION testcat.abc") }, - errorClass = "_LEGACY_ERROR_TEMP_1184", + condition = "_LEGACY_ERROR_TEMP_1184", parameters = Map( "plugin" -> "testcat", "ability" -> "functions" @@ -185,7 +185,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { exception = intercept[AnalysisException] { sql("DESCRIBE FUNCTION default.ns1.ns2.fun") }, - errorClass = "REQUIRES_SINGLE_PART_NAMESPACE", + condition = "REQUIRES_SINGLE_PART_NAMESPACE", parameters = Map( "sessionCatalog" -> "spark_catalog", "namespace" -> "`default`.`ns1`.`ns2`") @@ -343,7 +343,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { checkError( exception = intercept[AnalysisException](sql("SELECT testcat.ns.strlen(42)")), - errorClass = "_LEGACY_ERROR_TEMP_1198", + condition = "_LEGACY_ERROR_TEMP_1198", parameters = Map( "unbound" -> "strlen", "arguments" -> "int", @@ -358,7 +358,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { checkError( exception = intercept[AnalysisException](sql("SELECT testcat.ns.strlen('a', 'b')")), - errorClass = "_LEGACY_ERROR_TEMP_1198", + condition = "_LEGACY_ERROR_TEMP_1198", parameters = Map( "unbound" -> "strlen", "arguments" -> "string, string", @@ -414,7 +414,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { new JavaStrLen(new JavaStrLenNoImpl)) checkError( exception = intercept[AnalysisException](sql("SELECT testcat.ns.strlen('abc')").collect()), - errorClass = "_LEGACY_ERROR_TEMP_3055", + condition = "_LEGACY_ERROR_TEMP_3055", parameters = Map("scalarFunc" -> "strlen"), context = ExpectedContext( fragment = "testcat.ns.strlen('abc')", @@ -429,7 +429,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenBadInputTypes)) checkError( exception = intercept[AnalysisException](sql("SELECT testcat.ns.strlen('abc')").collect()), - errorClass = "_LEGACY_ERROR_TEMP_1199", + condition = "_LEGACY_ERROR_TEMP_1199", parameters = Map( "bound" -> "strlen_bad_input_types", "argsLen" -> "1", @@ -448,7 +448,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(new JavaLongAddMismatchMagic)) checkError( exception = intercept[AnalysisException](sql("SELECT testcat.ns.add(1L, 2L)").collect()), - errorClass = "_LEGACY_ERROR_TEMP_3055", + condition = "_LEGACY_ERROR_TEMP_3055", parameters = Map("scalarFunc" -> "long_add_mismatch_magic"), context = ExpectedContext( fragment = "testcat.ns.add(1L, 2L)", @@ -481,7 +481,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { exception = intercept[AnalysisException] { sql(sqlText).collect() }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> ".*", @@ -539,7 +539,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { checkError( exception = intercept[AnalysisException]( sql("SELECT testcat.ns.strlen('abc')")), - errorClass = "INVALID_UDF_IMPLEMENTATION", + 
condition = "INVALID_UDF_IMPLEMENTATION", parameters = Map( "funcName" -> "`bad_bound_func`"), context = ExpectedContext( @@ -602,7 +602,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { Seq(1.toShort, 2.toShort).toDF("i").write.saveAsTable(t) checkError( exception = intercept[AnalysisException](sql(s"SELECT testcat.ns.avg(i) from $t")), - errorClass = "_LEGACY_ERROR_TEMP_1198", + condition = "_LEGACY_ERROR_TEMP_1198", parameters = Map( "unbound" -> "iavg", "arguments" -> "smallint", @@ -637,7 +637,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { sql("SELECT testcat.ns.avg(*) from values " + "(date '2021-06-01' - date '2011-06-01'), (date '2000-01-01' - date '1900-01-01')") }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"v2aggregator(col1)\"", "paramIndex" -> "first", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala index 95624f3f61c5c..7463eb34d17ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala @@ -71,4 +71,12 @@ class DataSourceV2SQLSessionCatalogSuite sql(s"CREATE EXTERNAL TABLE t (i INT) USING $v2Format TBLPROPERTIES($prop)") } } + + test("SPARK-49152: partition columns should be put at the end") { + withTable("t") { + sql("CREATE TABLE t (c1 INT, c2 INT) USING json PARTITIONED BY (c1)") + // partition columns should be put at the end. + assert(getTableMetadata("default.t").columns().map(_.name()) === Seq("c2", "c1")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 1d37c6aa4eb7f..7aaec6d500ba0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -132,7 +132,7 @@ class DataSourceV2SQLSuiteV1Filter checkError( exception = analysisException(s"DESCRIBE $t invalid_col"), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`invalid_col`", "proposal" -> "`id`, `data`"), @@ -165,7 +165,7 @@ class DataSourceV2SQLSuiteV1Filter sql(s"CREATE TABLE $t (d struct) USING foo") checkError( exception = analysisException(s"describe $t d.a"), - errorClass = "_LEGACY_ERROR_TEMP_1060", + condition = "_LEGACY_ERROR_TEMP_1060", parameters = Map( "command" -> "DESC TABLE COLUMN", "column" -> "d.a")) @@ -219,7 +219,7 @@ class DataSourceV2SQLSuiteV1Filter spark.sql("CREATE TABLE testcat.table_name " + "(id bigint, data string, id2 bigint) USING bar") }, - errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> "`table_name`")) // table should not have changed @@ -302,14 +302,14 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"$action TABLE table_name (id int, value interval) USING $v2Format") }, - errorClass = "_LEGACY_ERROR_TEMP_1183", + condition = "_LEGACY_ERROR_TEMP_1183", parameters = Map.empty) checkError( exception = intercept[AnalysisException] { sql(s"$action TABLE table_name (id 
array) USING $v2Format") }, - errorClass = "_LEGACY_ERROR_TEMP_1183", + condition = "_LEGACY_ERROR_TEMP_1183", parameters = Map.empty) } } @@ -321,14 +321,14 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"$action TABLE table_name USING $v2Format as select interval 1 day") }, - errorClass = "_LEGACY_ERROR_TEMP_1183", + condition = "_LEGACY_ERROR_TEMP_1183", parameters = Map.empty) checkError( exception = intercept[AnalysisException] { sql(s"$action TABLE table_name USING $v2Format as select array(interval 1 day)") }, - errorClass = "_LEGACY_ERROR_TEMP_1183", + condition = "_LEGACY_ERROR_TEMP_1183", parameters = Map.empty) } } @@ -662,7 +662,7 @@ class DataSourceV2SQLSuiteV1Filter spark.sql(s"REPLACE TABLE $catalog.replaced USING $v2Source " + s"AS SELECT id, data FROM source") }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`replaced`")) } } @@ -677,7 +677,7 @@ class DataSourceV2SQLSuiteV1Filter s" TBLPROPERTIES (`$SIMULATE_DROP_BEFORE_REPLACE_PROPERTY`=true)" + s" AS SELECT id, data FROM source") }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`replaced`")) } @@ -720,7 +720,7 @@ class DataSourceV2SQLSuiteV1Filter spark.sql("CREATE TABLE testcat.table_name USING bar AS " + "SELECT id, data, id as id2 FROM source2") }, - errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> "`table_name`")) // table should not have changed @@ -1072,7 +1072,7 @@ class DataSourceV2SQLSuiteV1Filter checkError( exception = analysisException(s"SELECT ns1.ns2.ns3.tbl.id from $t"), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`ns1`.`ns2`.`ns3`.`tbl`.`id`", "proposal" -> "`testcat`.`ns1`.`ns2`.`tbl`.`id`, `testcat`.`ns1`.`ns2`.`tbl`.`point`"), @@ -1135,7 +1135,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"INSERT INTO $t1 VALUES(4)") }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`tbl`", "tableColumns" -> "`id`, `data`", @@ -1147,7 +1147,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"INSERT INTO $t1(data, data) VALUES(5)") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`data`")) } } @@ -1170,7 +1170,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"INSERT OVERWRITE $t1 VALUES(4)") }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`tbl`", "tableColumns" -> "`id`, `data`", @@ -1182,7 +1182,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`data`")) } } @@ -1206,7 +1206,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"INSERT OVERWRITE $t1 VALUES('a', 4)") }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = 
"INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`tbl`", "tableColumns" -> "`id`, `data`, `data2`", @@ -1218,7 +1218,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`data`")) } } @@ -1230,7 +1230,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql("INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2')") }, - errorClass = "STATIC_PARTITION_COLUMN_IN_INSERT_COLUMN_LIST", + condition = "STATIC_PARTITION_COLUMN_IN_INSERT_COLUMN_LIST", parameters = Map("staticName" -> "c")) } } @@ -1240,7 +1240,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql("SHOW VIEWS FROM a.b") }, - errorClass = "_LEGACY_ERROR_TEMP_1126", + condition = "_LEGACY_ERROR_TEMP_1126", parameters = Map("catalog" -> "a.b")) } @@ -1249,7 +1249,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql("SHOW VIEWS FROM testcat") }, - errorClass = "_LEGACY_ERROR_TEMP_1184", + condition = "_LEGACY_ERROR_TEMP_1184", parameters = Map("plugin" -> "testcat", "ability" -> "views")) } @@ -1271,7 +1271,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[ParseException] { sql(sqlText) }, - errorClass = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", + condition = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", parameters = Map( "property" -> key, "msg" -> keyParameters.getOrElse( @@ -1288,7 +1288,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[ParseException] { sql(sql1) }, - errorClass = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", + condition = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", parameters = Map( "property" -> key, "msg" -> keyParameters.getOrElse( @@ -1303,7 +1303,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[ParseException] { sql(sql2) }, - errorClass = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", + condition = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", parameters = Map( "property" -> key, "msg" -> keyParameters.getOrElse( @@ -1348,7 +1348,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[ParseException] { sql(sql1) }, - errorClass = "_LEGACY_ERROR_TEMP_0032", + condition = "_LEGACY_ERROR_TEMP_0032", parameters = Map("pathOne" -> "foo", "pathTwo" -> "bar"), context = ExpectedContext( fragment = sql1, @@ -1361,7 +1361,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[ParseException] { sql(sql2) }, - errorClass = "_LEGACY_ERROR_TEMP_0032", + condition = "_LEGACY_ERROR_TEMP_0032", parameters = Map("pathOne" -> "foo", "pathTwo" -> "bar"), context = ExpectedContext( fragment = sql2, @@ -1453,7 +1453,7 @@ class DataSourceV2SQLSuiteV1Filter sql("USE ns1") } checkError(exception, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> "`spark_catalog`.`ns1`")) } @@ -1464,7 +1464,7 @@ class DataSourceV2SQLSuiteV1Filter sql("USE testcat.ns1.ns2") } checkError(exception, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> "`testcat`.`ns1`.`ns2`")) } @@ -1503,7 +1503,7 @@ class DataSourceV2SQLSuiteV1Filter sql("USE dummy") sql(s"$statement dummy.$tableDefinition") }, - errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", parameters = Map( "tableName" -> "`dummy`.`my_tab`", "operation" -> 
"column default value" @@ -1535,7 +1535,7 @@ class DataSourceV2SQLSuiteV1Filter sql("USE dummy") sql(s"$statement dummy.$tableDefinition USING foo") }, - errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", parameters = Map( "tableName" -> "`dummy`.`my_tab`", "operation" -> "generated columns" @@ -1559,7 +1559,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"$statement testcat.$tableDefinition USING foo") }, - errorClass = "GENERATED_COLUMN_WITH_DEFAULT_VALUE", + condition = "GENERATED_COLUMN_WITH_DEFAULT_VALUE", parameters = Map( "colName" -> "eventYear", "defaultValue" -> "0", @@ -1584,7 +1584,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(customTableDef.getOrElse(tableDef)) }, - errorClass = "UNSUPPORTED_EXPRESSION_GENERATED_COLUMN", + condition = "UNSUPPORTED_EXPRESSION_GENERATED_COLUMN", parameters = Map( "fieldName" -> "b", "expressionStr" -> expr, @@ -1627,7 +1627,7 @@ class DataSourceV2SQLSuiteV1Filter sql(s"CREATE TABLE testcat.$tblName(a INT, " + "b INT GENERATED ALWAYS AS (B + 1)) USING foo") }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`B`", "proposal" -> "`a`"), context = ExpectedContext(fragment = "B", start = 0, stop = 0) ) @@ -1685,7 +1685,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"CREATE TABLE testcat.$tblName(a INT, b INT GENERATED ALWAYS AS (c + 1)) USING foo") }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`c`", "proposal" -> "`a`"), context = ExpectedContext(fragment = "c", start = 0, stop = 0) ) @@ -1753,6 +1753,64 @@ class DataSourceV2SQLSuiteV1Filter } } + test("SPARK-48824: Column cannot have both an identity column spec and a default value") { + val tblName = "my_tab" + val tableDefinition = + s"$tblName(id BIGINT GENERATED ALWAYS AS IDENTITY DEFAULT 0, name STRING)" + withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> "foo") { + for (statement <- Seq("CREATE TABLE", "REPLACE TABLE")) { + withTable(s"testcat.$tblName") { + if (statement == "REPLACE TABLE") { + sql(s"CREATE TABLE testcat.$tblName(a INT) USING foo") + } + checkError( + exception = intercept[AnalysisException] { + sql(s"$statement testcat.$tableDefinition USING foo") + }, + condition = "IDENTITY_COLUMN_WITH_DEFAULT_VALUE", + parameters = Map( + "colName" -> "id", + "defaultValue" -> "0", + "identityColumnSpec" -> + "IdentityColumnSpec{start=1, step=1, allowExplicitInsert=false}") + ) + } + } + } + } + + test("SPARK-48824: Identity columns only allowed with TableCatalogs that " + + "SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS") { + val tblName = "my_tab" + val tableDefinition = + s"$tblName(id BIGINT GENERATED ALWAYS AS IDENTITY(), val INT)" + for (statement <- Seq("CREATE TABLE", "REPLACE TABLE")) { + // InMemoryTableCatalog.capabilities() = {SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS} + withTable(s"testcat.$tblName") { + if (statement == "REPLACE TABLE") { + sql(s"CREATE TABLE testcat.$tblName(a INT) USING foo") + } + // Can create table with an identity column + sql(s"$statement testcat.$tableDefinition USING foo") + assert(catalog("testcat").asTableCatalog.tableExists(Identifier.of(Array(), tblName))) + } + // BasicInMemoryTableCatalog.capabilities() = {} + withSQLConf("spark.sql.catalog.dummy" -> 
classOf[BasicInMemoryTableCatalog].getName) { + checkError( + exception = intercept[AnalysisException] { + sql("USE dummy") + sql(s"$statement dummy.$tableDefinition USING foo") + }, + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + parameters = Map( + "tableName" -> "`dummy`.`my_tab`", + "operation" -> "identity column" + ) + ) + } + } + } + test("SPARK-46972: asymmetrical replacement for char/varchar in V2SessionCatalog.createTable") { // unset this config to use the default v2 session catalog. spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) @@ -1830,7 +1888,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(statement) }, - errorClass = "_LEGACY_ERROR_TEMP_3060", + condition = "_LEGACY_ERROR_TEMP_3060", parameters = Map( "i" -> i, "schema" -> @@ -1857,22 +1915,22 @@ class DataSourceV2SQLSuiteV1Filter withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { checkError( exception = analysisException(s"CREATE TABLE t ($c0 INT, $c1 INT) USING $v2Source"), - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c0.toLowerCase(Locale.ROOT)}`")) checkError( exception = analysisException( s"CREATE TABLE testcat.t ($c0 INT, $c1 INT) USING $v2Source"), - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c0.toLowerCase(Locale.ROOT)}`")) checkError( exception = analysisException( s"CREATE OR REPLACE TABLE t ($c0 INT, $c1 INT) USING $v2Source"), - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c0.toLowerCase(Locale.ROOT)}`")) checkError( exception = analysisException( s"CREATE OR REPLACE TABLE testcat.t ($c0 INT, $c1 INT) USING $v2Source"), - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c0.toLowerCase(Locale.ROOT)}`")) } } @@ -1884,23 +1942,23 @@ class DataSourceV2SQLSuiteV1Filter checkError( exception = analysisException( s"CREATE TABLE t (d struct<$c0: INT, $c1: INT>) USING $v2Source"), - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> toSQLId(s"d.${c0.toLowerCase(Locale.ROOT)}")) ) checkError( exception = analysisException( s"CREATE TABLE testcat.t (d struct<$c0: INT, $c1: INT>) USING $v2Source"), - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> toSQLId(s"d.${c0.toLowerCase(Locale.ROOT)}"))) checkError( exception = analysisException( s"CREATE OR REPLACE TABLE t (d struct<$c0: INT, $c1: INT>) USING $v2Source"), - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> toSQLId(s"d.${c0.toLowerCase(Locale.ROOT)}"))) checkError( exception = analysisException( s"CREATE OR REPLACE TABLE testcat.t (d struct<$c0: INT, $c1: INT>) USING $v2Source"), - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> toSQLId(s"d.${c0.toLowerCase(Locale.ROOT)}"))) } } @@ -1910,7 +1968,7 @@ class DataSourceV2SQLSuiteV1Filter checkError( exception = analysisException( s"CREATE TABLE tbl (a int, b string) USING $v2Source CLUSTERED BY (c) INTO 4 BUCKETS"), - errorClass = "_LEGACY_ERROR_TEMP_3060", + condition = "_LEGACY_ERROR_TEMP_3060", parameters = Map( "i" -> "c", "schema" -> @@ -1921,7 +1979,7 @@ class DataSourceV2SQLSuiteV1Filter checkError( exception = 
analysisException(s"CREATE TABLE testcat.tbl (a int, b string) " + s"USING $v2Source CLUSTERED BY (c) INTO 4 BUCKETS"), - errorClass = "_LEGACY_ERROR_TEMP_3060", + condition = "_LEGACY_ERROR_TEMP_3060", parameters = Map( "i" -> "c", "schema" -> @@ -1932,7 +1990,7 @@ class DataSourceV2SQLSuiteV1Filter checkError( exception = analysisException(s"CREATE OR REPLACE TABLE tbl (a int, b string) " + s"USING $v2Source CLUSTERED BY (c) INTO 4 BUCKETS"), - errorClass = "_LEGACY_ERROR_TEMP_3060", + condition = "_LEGACY_ERROR_TEMP_3060", parameters = Map( "i" -> "c", "schema" -> @@ -1943,7 +2001,7 @@ class DataSourceV2SQLSuiteV1Filter checkError( exception = analysisException(s"CREATE OR REPLACE TABLE testcat.tbl (a int, b string) " + s"USING $v2Source CLUSTERED BY (c) INTO 4 BUCKETS"), - errorClass = "_LEGACY_ERROR_TEMP_3060", + condition = "_LEGACY_ERROR_TEMP_3060", parameters = Map( "i" -> "c", "schema" -> @@ -1978,22 +2036,22 @@ class DataSourceV2SQLSuiteV1Filter checkError( exception = analysisException( s"CREATE TABLE t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)"), - errorClass = "_LEGACY_ERROR_TEMP_3058", + condition = "_LEGACY_ERROR_TEMP_3058", parameters = Map("checkType" -> "in the partitioning", "duplicateColumns" -> dupCol)) checkError( exception = analysisException( s"CREATE TABLE testcat.t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)"), - errorClass = "_LEGACY_ERROR_TEMP_3058", + condition = "_LEGACY_ERROR_TEMP_3058", parameters = Map("checkType" -> "in the partitioning", "duplicateColumns" -> dupCol)) checkError( exception = analysisException( s"CREATE OR REPLACE TABLE t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)"), - errorClass = "_LEGACY_ERROR_TEMP_3058", + condition = "_LEGACY_ERROR_TEMP_3058", parameters = Map("checkType" -> "in the partitioning", "duplicateColumns" -> dupCol)) checkError( exception = analysisException(s"CREATE OR REPLACE TABLE testcat.t ($c0 INT) " + s"USING $v2Source PARTITIONED BY ($c0, $c1)"), - errorClass = "_LEGACY_ERROR_TEMP_3058", + condition = "_LEGACY_ERROR_TEMP_3058", parameters = Map("checkType" -> "in the partitioning", "duplicateColumns" -> dupCol)) } } @@ -2007,26 +2065,26 @@ class DataSourceV2SQLSuiteV1Filter exception = analysisException( s"CREATE TABLE t ($c0 INT) USING $v2Source " + s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS"), - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map( "columnName" -> s"`${c0.toLowerCase(Locale.ROOT)}`")) checkError( exception = analysisException( s"CREATE TABLE testcat.t ($c0 INT) USING $v2Source " + s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS"), - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c0.toLowerCase(Locale.ROOT)}`")) checkError( exception = analysisException( s"CREATE OR REPLACE TABLE t ($c0 INT) USING $v2Source " + s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS"), - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c0.toLowerCase(Locale.ROOT)}`")) checkError( exception = analysisException( s"CREATE OR REPLACE TABLE testcat.t ($c0 INT) USING $v2Source " + s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS"), - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c0.toLowerCase(Locale.ROOT)}`")) } } @@ -2125,10 +2183,18 @@ class DataSourceV2SQLSuiteV1Filter } test("REPLACE TABLE: v1 table") { - sql(s"CREATE OR REPLACE TABLE tbl (a int) USING 
${classOf[SimpleScanSource].getName}") - val v2Catalog = catalog("spark_catalog").asTableCatalog - val table = v2Catalog.loadTable(Identifier.of(Array("default"), "tbl")) - assert(table.properties().get(TableCatalog.PROP_PROVIDER) == classOf[SimpleScanSource].getName) + val e = intercept[AnalysisException] { + sql(s"CREATE OR REPLACE TABLE tbl (a int) USING ${classOf[SimpleScanSource].getName}") + } + checkError( + exception = e, + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + sqlState = "0A000", + parameters = Map( + "tableName" -> "`spark_catalog`.`default`.`tbl`", + "operation" -> "REPLACE TABLE" + ) + ) } test("DeleteFrom: - delete with invalid predicate") { @@ -2140,7 +2206,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"DELETE FROM $t WHERE id = 2 AND id = id") }, - errorClass = "_LEGACY_ERROR_TEMP_1110", + condition = "_LEGACY_ERROR_TEMP_1110", parameters = Map( "table" -> "testcat.ns1.ns2.tbl", "filters" -> "[id = 2, id = id]")) @@ -2161,7 +2227,7 @@ class DataSourceV2SQLSuiteV1Filter // UPDATE non-existing table checkError( exception = analysisException("UPDATE dummy SET name='abc'"), - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`dummy`"), context = ExpectedContext( fragment = "dummy", @@ -2171,7 +2237,7 @@ class DataSourceV2SQLSuiteV1Filter // UPDATE non-existing column checkError( exception = analysisException(s"UPDATE $t SET dummy='abc'"), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`dummy`", "proposal" -> "`age`, `id`, `name`, `p`" @@ -2182,7 +2248,7 @@ class DataSourceV2SQLSuiteV1Filter stop = 41)) checkError( exception = analysisException(s"UPDATE $t SET name='abc' WHERE dummy=1"), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`dummy`", "proposal" -> "`age`, `id`, `name`, `p`" @@ -2197,7 +2263,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[SparkUnsupportedOperationException] { sql(s"UPDATE $t SET name='Robert', age=32 WHERE p=1") }, - errorClass = "_LEGACY_ERROR_TEMP_2096", + condition = "_LEGACY_ERROR_TEMP_2096", parameters = Map("ddl" -> "UPDATE TABLE") ) } @@ -2232,7 +2298,7 @@ class DataSourceV2SQLSuiteV1Filter |WHEN NOT MATCHED AND (target.col2='insert') |THEN INSERT * """.stripMargin), - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`testcat`.`ns1`.`ns2`.`dummy`"), context = ExpectedContext( fragment = "testcat.ns1.ns2.dummy", @@ -2252,7 +2318,7 @@ class DataSourceV2SQLSuiteV1Filter |WHEN NOT MATCHED AND (target.col2='insert') |THEN INSERT * """.stripMargin), - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`testcat`.`ns1`.`ns2`.`dummy`"), context = ExpectedContext( fragment = "testcat.ns1.ns2.dummy", @@ -2270,7 +2336,7 @@ class DataSourceV2SQLSuiteV1Filter |THEN INSERT *""".stripMargin checkError( exception = analysisException(sql1), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`target`.`dummy`", "proposal" -> "`age`, `id`, `name`, `p`"), @@ -2286,7 +2352,7 @@ class DataSourceV2SQLSuiteV1Filter |WHEN MATCHED AND (target.age > 10) THEN UPDATE SET target.age = source.dummy |WHEN NOT MATCHED AND (target.col2='insert') |THEN 
INSERT *""".stripMargin), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`source`.`dummy`", "proposal" -> "`age`, `age`, `id`, `id`, `name`, `name`, `p`, `p`"), @@ -2303,7 +2369,7 @@ class DataSourceV2SQLSuiteV1Filter |WHEN MATCHED AND (target.p > 0) THEN UPDATE SET * |WHEN NOT MATCHED THEN INSERT *""".stripMargin) }, - errorClass = "_LEGACY_ERROR_TEMP_2096", + condition = "_LEGACY_ERROR_TEMP_2096", parameters = Map("ddl" -> "MERGE INTO TABLE")) } } @@ -2316,7 +2382,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql("ALTER VIEW testcat.ns1.ns2.old RENAME TO ns1.new") }, - errorClass = "_LEGACY_ERROR_TEMP_1123", + condition = "_LEGACY_ERROR_TEMP_1123", parameters = Map.empty) } } @@ -2410,12 +2476,12 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"CREATE VIEW $v AS SELECT 1") }, - errorClass = "_LEGACY_ERROR_TEMP_1184", + condition = "_LEGACY_ERROR_TEMP_1184", parameters = Map("plugin" -> "testcat", "ability" -> "views")) } test("global temp view should not be masked by v2 catalog") { - val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE) + val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE.key) registerCatalog(globalTempDB, classOf[InMemoryTableCatalog]) try { @@ -2429,7 +2495,7 @@ class DataSourceV2SQLSuiteV1Filter } test("SPARK-30104: global temp db is used as a table name under v2 catalog") { - val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE) + val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE.key) val t = s"testcat.$globalTempDB" withTable(t) { sql(s"CREATE TABLE $t (id bigint, data string) USING foo") @@ -2440,7 +2506,7 @@ class DataSourceV2SQLSuiteV1Filter } test("SPARK-30104: v2 catalog named global_temp will be masked") { - val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE) + val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE.key) registerCatalog(globalTempDB, classOf[InMemoryTableCatalog]) checkError( exception = intercept[AnalysisException] { @@ -2448,7 +2514,7 @@ class DataSourceV2SQLSuiteV1Filter // the session catalog, not the `global_temp` v2 catalog. sql(s"CREATE TABLE $globalTempDB.ns1.ns2.tbl (id bigint, data string) USING json") }, - errorClass = "REQUIRES_SINGLE_PART_NAMESPACE", + condition = "REQUIRES_SINGLE_PART_NAMESPACE", parameters = Map( "sessionCatalog" -> "spark_catalog", "namespace" -> "`global_temp`.`ns1`.`ns2`")) @@ -2484,7 +2550,7 @@ class DataSourceV2SQLSuiteV1Filter def verify(sql: String): Unit = { checkError( exception = intercept[AnalysisException](spark.sql(sql)), - errorClass = "REQUIRES_SINGLE_PART_NAMESPACE", + condition = "REQUIRES_SINGLE_PART_NAMESPACE", parameters = Map("sessionCatalog" -> "spark_catalog", "namespace" -> "")) } @@ -2560,7 +2626,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"CREATE VIEW $sessionCatalogName.default.v AS SELECT * FROM t") }, - errorClass = "INVALID_TEMP_OBJ_REFERENCE", + condition = "INVALID_TEMP_OBJ_REFERENCE", parameters = Map( "obj" -> "VIEW", "objName" -> "`spark_catalog`.`default`.`v`", @@ -2592,7 +2658,7 @@ class DataSourceV2SQLSuiteV1Filter checkError( exception = intercept[AnalysisException](sql("COMMENT ON NAMESPACE abc IS NULL")), - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> "`spark_catalog`.`abc`")) // V2 non-session catalog is used. 
@@ -2602,7 +2668,7 @@ class DataSourceV2SQLSuiteV1Filter checkNamespaceComment("testcat.ns1", "NULL") checkError( exception = intercept[AnalysisException](sql("COMMENT ON NAMESPACE testcat.abc IS NULL")), - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> "`testcat`.`abc`")) } @@ -2628,7 +2694,7 @@ class DataSourceV2SQLSuiteV1Filter val sql1 = "COMMENT ON TABLE abc IS NULL" checkError( exception = intercept[AnalysisException](sql(sql1)), - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`abc`"), context = ExpectedContext(fragment = "abc", start = 17, stop = 19)) @@ -2642,17 +2708,17 @@ class DataSourceV2SQLSuiteV1Filter val sql2 = "COMMENT ON TABLE testcat.abc IS NULL" checkError( exception = intercept[AnalysisException](sql(sql2)), - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`testcat`.`abc`"), context = ExpectedContext(fragment = "testcat.abc", start = 17, stop = 27)) - val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE) + val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE.key) registerCatalog(globalTempDB, classOf[InMemoryTableCatalog]) withTempView("v") { sql("create global temp view v as select 1") checkError( exception = intercept[AnalysisException](sql("COMMENT ON TABLE global_temp.v IS NULL")), - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> "`global_temp`.`v`", "operation" -> "COMMENT ON TABLE"), @@ -2692,7 +2758,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"SELECT ns1.ns2.ns3.tbl.* from $t") }, - errorClass = "CANNOT_RESOLVE_STAR_EXPAND", + condition = "CANNOT_RESOLVE_STAR_EXPAND", parameters = Map( "targetString" -> "`ns1`.`ns2`.`ns3`.`tbl`", "columns" -> "`id`, `name`"), @@ -2756,7 +2822,7 @@ class DataSourceV2SQLSuiteV1Filter val e = intercept[AnalysisException](sql(sqlStatement)) checkError( e, - errorClass = "UNSUPPORTED_FEATURE.CATALOG_OPERATION", + condition = "UNSUPPORTED_FEATURE.CATALOG_OPERATION", parameters = Map("catalogName" -> "`testcat`", "operation" -> "views")) } @@ -2815,7 +2881,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[CatalogNotFoundException] { sql("SET CATALOG not_exist_catalog") }, - errorClass = "CATALOG_NOT_FOUND", + condition = "CATALOG_NOT_FOUND", parameters = Map( "catalogName" -> "`not_exist_catalog`", "config" -> "\"spark.sql.catalog.not_exist_catalog\"")) @@ -2851,7 +2917,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`non_exist`", @@ -2863,7 +2929,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(sql2) }, - errorClass = "_LEGACY_ERROR_TEMP_1332", + condition = "_LEGACY_ERROR_TEMP_1332", parameters = Map( "errorMessage" -> "CreateIndex is not supported in this table testcat.tbl.")) } @@ -3074,7 +3140,7 @@ class DataSourceV2SQLSuiteV1Filter // a fake time travel implementation that only supports two hardcoded timestamp values. 
sql("SELECT * FROM t TIMESTAMP AS OF current_date()") }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`t`"), context = ExpectedContext( fragment = "t", @@ -3085,7 +3151,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql("SELECT * FROM t TIMESTAMP AS OF INTERVAL 1 DAY").collect() }, - errorClass = "INVALID_TIME_TRAVEL_TIMESTAMP_EXPR.INPUT", + condition = "INVALID_TIME_TRAVEL_TIMESTAMP_EXPR.INPUT", parameters = Map( "expr" -> "\"INTERVAL '1' DAY\"")) @@ -3093,14 +3159,14 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql("SELECT * FROM t TIMESTAMP AS OF 'abc'").collect() }, - errorClass = "INVALID_TIME_TRAVEL_TIMESTAMP_EXPR.INPUT", + condition = "INVALID_TIME_TRAVEL_TIMESTAMP_EXPR.INPUT", parameters = Map("expr" -> "\"abc\"")) checkError( exception = intercept[AnalysisException] { spark.read.option("timestampAsOf", "abc").table("t").collect() }, - errorClass = "INVALID_TIME_TRAVEL_TIMESTAMP_EXPR.OPTION", + condition = "INVALID_TIME_TRAVEL_TIMESTAMP_EXPR.OPTION", parameters = Map("expr" -> "'abc'")) checkError( @@ -3111,27 +3177,27 @@ class DataSourceV2SQLSuiteV1Filter .table("t") .collect() }, - errorClass = "INVALID_TIME_TRAVEL_SPEC") + condition = "INVALID_TIME_TRAVEL_SPEC") checkError( exception = intercept[AnalysisException] { sql("SELECT * FROM t TIMESTAMP AS OF current_user()").collect() }, - errorClass = "INVALID_TIME_TRAVEL_TIMESTAMP_EXPR.UNEVALUABLE", + condition = "INVALID_TIME_TRAVEL_TIMESTAMP_EXPR.UNEVALUABLE", parameters = Map("expr" -> "\"current_user()\"")) checkError( exception = intercept[AnalysisException] { sql("SELECT * FROM t TIMESTAMP AS OF CAST(rand() AS STRING)").collect() }, - errorClass = "INVALID_TIME_TRAVEL_TIMESTAMP_EXPR.NON_DETERMINISTIC", + condition = "INVALID_TIME_TRAVEL_TIMESTAMP_EXPR.NON_DETERMINISTIC", parameters = Map("expr" -> "\"CAST(rand() AS STRING)\"")) checkError( exception = intercept[AnalysisException] { sql("SELECT * FROM t TIMESTAMP AS OF abs(true)").collect() }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", sqlState = None, parameters = Map( "sqlExpr" -> "\"abs(true)\"", @@ -3149,7 +3215,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql("SELECT * FROM parquet.`/the/path` VERSION AS OF 1") }, - errorClass = "UNSUPPORTED_FEATURE.TIME_TRAVEL", + condition = "UNSUPPORTED_FEATURE.TIME_TRAVEL", sqlState = None, parameters = Map("relationId" -> "`parquet`.`/the/path`")) @@ -3157,7 +3223,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql("WITH x AS (SELECT 1) SELECT * FROM x VERSION AS OF 1") }, - errorClass = "UNSUPPORTED_FEATURE.TIME_TRAVEL", + condition = "UNSUPPORTED_FEATURE.TIME_TRAVEL", sqlState = None, parameters = Map("relationId" -> "`x`")) @@ -3165,7 +3231,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql("SELECT * FROM non_exist VERSION AS OF 1") }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`non_exist`"), context = ExpectedContext( fragment = "non_exist", @@ -3177,7 +3243,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"SELECT * FROM t TIMESTAMP AS OF ($subquery1)").collect() }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`non_exist`"), 
ExpectedContext( fragment = "non_exist", @@ -3188,7 +3254,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"SELECT * FROM t TIMESTAMP AS OF (SELECT ($subquery1))").collect() }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`non_exist`"), ExpectedContext( fragment = "non_exist", @@ -3200,7 +3266,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"SELECT * FROM t TIMESTAMP AS OF ($subquery2)").collect() }, - errorClass = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", parameters = Map("objectName" -> "`col`"), ExpectedContext( fragment = "col", @@ -3210,7 +3276,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"SELECT * FROM t TIMESTAMP AS OF (SELECT ($subquery2))").collect() }, - errorClass = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", parameters = Map("objectName" -> "`col`"), ExpectedContext( fragment = "col", @@ -3222,7 +3288,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"SELECT * FROM t TIMESTAMP AS OF ($subquery3)").collect() }, - errorClass = + condition = "INVALID_SUBQUERY_EXPRESSION.SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN", parameters = Map("number" -> "2"), ExpectedContext( @@ -3233,7 +3299,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"SELECT * FROM t TIMESTAMP AS OF (SELECT ($subquery3))").collect() }, - errorClass = + condition = "INVALID_SUBQUERY_EXPRESSION.SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN", parameters = Map("number" -> "2"), ExpectedContext( @@ -3246,7 +3312,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[SparkException] { sql(s"SELECT * FROM t TIMESTAMP AS OF ($subquery4)").collect() }, - errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS", + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS", parameters = Map.empty, ExpectedContext( fragment = "(SELECT * FROM VALUES (1), (2))", @@ -3256,7 +3322,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[SparkException] { sql(s"SELECT * FROM t TIMESTAMP AS OF (SELECT ($subquery4))").collect() }, - errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS", + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS", parameters = Map.empty, ExpectedContext( fragment = "(SELECT * FROM VALUES (1), (2))", @@ -3360,7 +3426,7 @@ class DataSourceV2SQLSuiteV1Filter } checkError( exception, - errorClass = "UNSUPPORTED_FEATURE.OVERWRITE_BY_SUBQUERY", + condition = "UNSUPPORTED_FEATURE.OVERWRITE_BY_SUBQUERY", sqlState = "0A000", parameters = Map.empty, context = ExpectedContext( @@ -3781,7 +3847,7 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"$sqlCommand $sqlParams") }, - errorClass = "NOT_SUPPORTED_COMMAND_FOR_V2_TABLE", + condition = "NOT_SUPPORTED_COMMAND_FOR_V2_TABLE", sqlState = "0A000", parameters = Map("cmd" -> expectedArgument.getOrElse(sqlCommand))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 1de535df246b7..d61d554025e50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -454,7 +454,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with 
AdaptiveS .write.format(cls.getName) .option("path", path).mode("ignore").save() }, - errorClass = "UNSUPPORTED_DATA_SOURCE_SAVE_MODE", + condition = "UNSUPPORTED_DATA_SOURCE_SAVE_MODE", parameters = Map( "source" -> cls.getName, "createMode" -> "\"Ignore\"" @@ -467,7 +467,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS .write.format(cls.getName) .option("path", path).mode("error").save() }, - errorClass = "UNSUPPORTED_DATA_SOURCE_SAVE_MODE", + condition = "UNSUPPORTED_DATA_SOURCE_SAVE_MODE", parameters = Map( "source" -> cls.getName, "createMode" -> "\"ErrorIfExists\"" @@ -651,7 +651,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS exception = intercept[SparkUnsupportedOperationException] { sql(s"CREATE TABLE test(a INT, b INT) USING ${cls.getName}") }, - errorClass = "CANNOT_CREATE_DATA_SOURCE_TABLE.EXTERNAL_METADATA_UNSUPPORTED", + condition = "CANNOT_CREATE_DATA_SOURCE_TABLE.EXTERNAL_METADATA_UNSUPPORTED", parameters = Map("tableName" -> "`default`.`test`", "provider" -> cls.getName) ) } @@ -732,7 +732,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS exception = intercept[AnalysisException] { sql(s"CREATE TABLE test (x INT, y INT) USING ${cls.getName}") }, - errorClass = "DATA_SOURCE_TABLE_SCHEMA_MISMATCH", + condition = "DATA_SOURCE_TABLE_SCHEMA_MISMATCH", parameters = Map( "dsSchema" -> "\"STRUCT\"", "expectedSchema" -> "\"STRUCT\"")) @@ -770,7 +770,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS exception = intercept[AnalysisException] { sql(s"CREATE TABLE test USING ${cls.getName} AS VALUES (0, 1), (1, 2)") }, - errorClass = "DATA_SOURCE_TABLE_SCHEMA_MISMATCH", + condition = "DATA_SOURCE_TABLE_SCHEMA_MISMATCH", parameters = Map( "dsSchema" -> "\"STRUCT\"", "expectedSchema" -> "\"STRUCT\"")) @@ -788,7 +788,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS |AS VALUES ('a', 'b'), ('c', 'd') t(i, j) |""".stripMargin) }, - errorClass = "DATA_SOURCE_TABLE_SCHEMA_MISMATCH", + condition = "DATA_SOURCE_TABLE_SCHEMA_MISMATCH", parameters = Map( "dsSchema" -> "\"STRUCT\"", "expectedSchema" -> "\"STRUCT\"")) @@ -839,7 +839,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS exception = intercept[SparkUnsupportedOperationException] { sql(s"CREATE TABLE test USING ${cls.getName} AS VALUES (0, 1)") }, - errorClass = "CANNOT_CREATE_DATA_SOURCE_TABLE.EXTERNAL_METADATA_UNSUPPORTED", + condition = "CANNOT_CREATE_DATA_SOURCE_TABLE.EXTERNAL_METADATA_UNSUPPORTED", parameters = Map( "tableName" -> "`default`.`test`", "provider" -> "org.apache.spark.sql.connector.SimpleDataSourceV2")) @@ -851,7 +851,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS exception = intercept[AnalysisException] { sql(s"CREATE TABLE test USING ${cls.getName} AS SELECT * FROM VALUES (0, 1)") }, - errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`test`", "operation" -> "append in batch mode")) @@ -881,7 +881,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS exception = intercept[AnalysisException] { sql(s"INSERT INTO test VALUES (4)") }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> 
"`spark_catalog`.`default`.`test`", "tableColumns" -> "`x`, `y`", @@ -893,7 +893,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS exception = intercept[AnalysisException] { sql(s"INSERT INTO test(x, x) VALUES (4, 5)") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`x`")) } } @@ -935,13 +935,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS exception = intercept[AnalysisException] { sql("INSERT INTO test PARTITION(z = 1) VALUES (2)") }, - errorClass = "NON_PARTITION_COLUMN", + condition = "NON_PARTITION_COLUMN", parameters = Map("columnName" -> "`z`")) checkError( exception = intercept[AnalysisException] { sql("INSERT INTO test PARTITION(x, y = 1) VALUES (2, 3)") }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`test`", "tableColumns" -> "`x`, `y`", @@ -959,7 +959,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS exception = intercept[AnalysisException] { sql("INSERT OVERWRITE test PARTITION(x = 1) VALUES (5)") }, - errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`test`", "operation" -> "overwrite by filter in batch mode") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTests.scala index eeef0566b8faf..fd022580db42b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTests.scala @@ -112,7 +112,7 @@ trait DeleteFromTests extends DatasourceV2SQLBase { checkError( exception = exc, - errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", sqlState = "0A000", parameters = Map("tableName" -> "`spark_catalog`.`default`.`tbl`", "operation" -> "DELETE")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala index 7336b3a6e9206..be180eb89ce20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala @@ -37,7 +37,7 @@ class DeltaBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { checkError( exception = intercept[AnalysisException]( sql(s"DELETE FROM $tableNameAsString WHERE id <= 1 AND rand() > 0.5")), - errorClass = "INVALID_NON_DETERMINISTIC_EXPRESSIONS", + condition = "INVALID_NON_DETERMINISTIC_EXPRESSIONS", parameters = Map( "sqlExprs" -> "\"((id <= 1) AND (rand() > 0.5))\""), context = ExpectedContext( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala index 46942cac1c7e3..89b42b5e6db7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala @@ -32,7 +32,7 @@ abstract class DeltaBasedUpdateTableSuiteBase extends 
UpdateTableSuiteBase { exception = intercept[AnalysisException] { sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE pk = 1") }, - errorClass = "NULLABLE_ROW_ID_ATTRIBUTES", + condition = "NULLABLE_ROW_ID_ATTRIBUTES", parameters = Map("nullableRowIdAttrs" -> "pk#\\d+") ) } @@ -62,7 +62,7 @@ abstract class DeltaBasedUpdateTableSuiteBase extends UpdateTableSuiteBase { exception = intercept[AnalysisException] { sql(s"UPDATE $tableNameAsString SET dep = 'invalid' WHERE id <= 1 AND rand() > 0.5") }, - errorClass = "INVALID_NON_DETERMINISTIC_EXPRESSIONS", + condition = "INVALID_NON_DETERMINISTIC_EXPRESSIONS", parameters = Map("sqlExprs" -> "\"((id <= 1) AND (rand() > 0.5))\""), context = ExpectedContext( fragment = "UPDATE cat.ns1.test_table SET dep = 'invalid' WHERE id <= 1 AND rand() > 0.5", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala index c6060dcdd51a7..2a0ab21ddb09c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala @@ -103,7 +103,7 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkException] { spark.read.format(dummyReadOnlyFileSourceV2).load(path).collect() }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "Dummy file reader")) } } @@ -131,7 +131,7 @@ class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkException] { spark.read.format(dummyReadOnlyFileSourceV2).load(path).collect() }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "Dummy file reader")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala index 0aeab95f58a7b..1be318f948fd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala @@ -34,7 +34,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { checkError( exception = intercept[AnalysisException]( sql(s"DELETE FROM $tableNameAsString WHERE id <= 1 AND rand() > 0.5")), - errorClass = "INVALID_NON_DETERMINISTIC_EXPRESSIONS", + condition = "INVALID_NON_DETERMINISTIC_EXPRESSIONS", parameters = Map( "sqlExprs" -> "\"((id <= 1) AND (rand() > 0.5))\", \"((id <= 1) AND (rand() > 0.5))\""), context = ExpectedContext( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala index 3e736421a315c..774ae97734d25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala @@ -122,7 +122,7 @@ class GroupBasedUpdateTableSuite extends UpdateTableSuiteBase { exception = intercept[AnalysisException] { sql(s"UPDATE $tableNameAsString SET dep = 'invalid' WHERE id <= 1 AND rand() > 0.5") }, - errorClass = "INVALID_NON_DETERMINISTIC_EXPRESSIONS", + condition = "INVALID_NON_DETERMINISTIC_EXPRESSIONS", parameters = 
Map( "sqlExprs" -> "\"((id <= 1) AND (rand() > 0.5))\", \"((id <= 1) AND (rand() > 0.5))\""), context = ExpectedContext( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala index fa30969d65c52..d6e86bc93c9d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala @@ -135,7 +135,7 @@ abstract class InsertIntoTests( exception = intercept[AnalysisException] { doInsert(t1, df) }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> tableName, "tableColumns" -> "`id`, `data`, `missing`", @@ -158,7 +158,7 @@ abstract class InsertIntoTests( exception = intercept[AnalysisException] { doInsert(t1, df) }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", parameters = Map( "tableName" -> tableName, "tableColumns" -> "`id`, `data`", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala index c080a66bce257..8aa8fb21f4ae3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.Row import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.MergeIntoWriterImpl class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { @@ -950,7 +951,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { // an arbitrary merge val writer1 = spark.table("source") - .mergeInto("dummy", $"col" === $"col") + .mergeInto("dummy", $"colA" === $"colB") .whenMatched(col("col") === 1) .updateAll() .whenMatched() @@ -959,16 +960,15 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { .insertAll() .whenNotMatchedBySource(col("col") === 1) .delete() + .asInstanceOf[MergeIntoWriterImpl[Row]] val writer2 = writer1.withSchemaEvolution() + .asInstanceOf[MergeIntoWriterImpl[Row]] + assert(writer1 eq writer2) assert(writer1.matchedActions.length === 2) assert(writer1.notMatchedActions.length === 1) assert(writer1.notMatchedBySourceActions.length === 1) - - assert(writer1.matchedActions === writer2.matchedActions) - assert(writer1.notMatchedActions === writer2.notMatchedActions) - assert(writer1.notMatchedBySourceActions === writer2.notMatchedBySourceActions) - assert(writer2.schemaEvolutionEnabled) + assert(writer1.schemaEvolutionEnabled) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala index b043bf2f5be23..741e30a739f5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala @@ -303,7 +303,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase { exception = intercept[AnalysisException] { df.metadataColumn("foo") }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> 
"`foo`", "proposal" -> "`index`, `_partition`"), queryContext = Array(ExpectedContext("select index from testcat.t", 0, 26))) @@ -312,7 +312,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase { exception = intercept[AnalysisException] { df.metadataColumn("data") }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`data`", "proposal" -> "`index`, `_partition`"), queryContext = Array(ExpectedContext("select index from testcat.t", 0, 26))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala new file mode 100644 index 0000000000000..c8faf5a874f5f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala @@ -0,0 +1,656 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import java.util.Collections + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkNumberFormatException} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId +import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog} +import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure} +import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter.Mode +import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter.Mode.{IN, INOUT, OUT} +import org.apache.spark.sql.connector.read.{LocalScan, Scan} +import org.apache.spark.sql.errors.DataTypeErrors.{toSQLType, toSQLValue} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String + +class ProcedureSuite extends QueryTest with SharedSparkSession with BeforeAndAfter { + + before { + spark.conf.set(s"spark.sql.catalog.cat", classOf[InMemoryCatalog].getName) + } + + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.unsetConf(s"spark.sql.catalog.cat") + } + + private def catalog: InMemoryCatalog = { + val catalog = spark.sessionState.catalogManager.catalog("cat") + catalog.asInstanceOf[InMemoryCatalog] + } + + test("position arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(5, 5)"), Row(10) :: Nil) + } + + test("named arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), 
UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(in2 => 3, in1 => 5)"), Row(8) :: Nil) + } + + test("position and named arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(3, in2 => 1)"), Row(4) :: Nil) + } + + test("foldable expressions") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(1 + 1, in2 => 2)"), Row(4) :: Nil) + checkAnswer(sql("CALL cat.ns.sum(in2 => 1, in1 => 2 + 1)"), Row(4) :: Nil) + checkAnswer(sql("CALL cat.ns.sum((1 + 1) * 2, in2 => (2 + 1) / 3)"), Row(5) :: Nil) + } + + test("type coercion") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundLongSum) + checkAnswer(sql("CALL cat.ns.sum(1, 2)"), Row(3) :: Nil) + checkAnswer(sql("CALL cat.ns.sum(1L, 2)"), Row(3) :: Nil) + checkAnswer(sql("CALL cat.ns.sum(1, 2L)"), Row(3) :: Nil) + } + + test("multiple output rows") { + catalog.createProcedure(Identifier.of(Array("ns"), "complex"), UnboundComplexProcedure) + checkAnswer( + sql("CALL cat.ns.complex('X', 'Y', 3)"), + Row(1, "X1", "Y1") :: Row(2, "X2", "Y2") :: Row(3, "X3", "Y3") :: Nil) + } + + test("parameters with default values") { + catalog.createProcedure(Identifier.of(Array("ns"), "complex"), UnboundComplexProcedure) + checkAnswer(sql("CALL cat.ns.complex()"), Row(1, "A1", "B1") :: Nil) + checkAnswer(sql("CALL cat.ns.complex('X', 'Y')"), Row(1, "X1", "Y1") :: Nil) + } + + test("parameters with invalid default values") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundInvalidDefaultProcedure) + checkError( + exception = intercept[AnalysisException]( + sql("CALL cat.ns.sum()") + ), + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", + parameters = Map( + "statement" -> "CALL", + "colName" -> toSQLId("in2"), + "defaultValue" -> toSQLValue("B"), + "expectedType" -> toSQLType("INT"), + "actualType" -> toSQLType("STRING"))) + } + + test("IDENTIFIER") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer( + spark.sql("CALL IDENTIFIER(:p1)(1, 2)", Map("p1" -> "cat.ns.sum")), + Row(3) :: Nil) + } + + test("parameterized statements") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer( + spark.sql("CALL cat.ns.sum(?, ?)", Array(2, 3)), + Row(5) :: Nil) + } + + test("undefined procedure") { + checkError( + exception = intercept[AnalysisException]( + sql("CALL cat.non_exist(1, 2)") + ), + sqlState = Some("38000"), + condition = "FAILED_TO_LOAD_ROUTINE", + parameters = Map("routineName" -> "`cat`.`non_exist`") + ) + } + + test("non-procedure catalog") { + withSQLConf("spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) { + checkError( + exception = intercept[AnalysisException]( + sql("CALL testcat.procedure(1, 2)") + ), + condition = "_LEGACY_ERROR_TEMP_1184", + parameters = Map("plugin" -> "testcat", "ability" -> "procedures") + ) + } + } + + test("too many arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException]( + sql("CALL cat.ns.sum(1, 2, 3)") + ), + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + parameters = Map( + "functionName" -> toSQLId("sum"), + "expectedNum" -> "2", + "actualNum" -> "3", + "docroot" -> SPARK_DOC_ROOT)) + } + + test("custom default catalog") { + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "cat") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val df = sql("CALL ns.sum(1, 2)") + 
checkAnswer(df, Row(3) :: Nil) + } + } + + test("custom default catalog and namespace") { + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "cat") { + catalog.createNamespace(Array("ns"), Collections.emptyMap) + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + sql("USE ns") + val df = sql("CALL sum(1, 2)") + checkAnswer(df, Row(3) :: Nil) + } + } + + test("required parameter not found") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum()") + }, + condition = "REQUIRED_PARAMETER_NOT_FOUND", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"), + "index" -> "0")) + } + + test("conflicting position and named parameter assignments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(1, in1 => 2)") + }, + condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"))) + } + + test("duplicate named parameter assignments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(in1 => 1, in1 => 2)") + }, + condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"))) + } + + test("unknown parameter name") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(in1 => 1, in5 => 2)") + }, + condition = "UNRECOGNIZED_PARAMETER_NAME", + parameters = Map( + "routineName" -> toSQLId("sum"), + "argumentName" -> toSQLId("in5"), + "proposal" -> (toSQLId("in1") + " " + toSQLId("in2")))) + } + + test("position parameter after named parameter") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(in1 => 1, 2)") + }, + condition = "UNEXPECTED_POSITIONAL_ARGUMENT", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"))) + } + + test("invalid argument type") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val call = "CALL cat.ns.sum(1, TIMESTAMP '2016-11-15 20:54:00.000')" + checkError( + exception = intercept[AnalysisException] { + sql(call) + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "CALL", + "paramIndex" -> "second", + "inputSql" -> "\"TIMESTAMP '2016-11-15 20:54:00'\"", + "inputType" -> toSQLType("TIMESTAMP"), + "requiredType" -> toSQLType("INT")), + context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1)) + } + + test("malformed input to implicit cast") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> true.toString) { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val call = "CALL cat.ns.sum('A', 2)" + checkError( + exception = intercept[SparkNumberFormatException]( + sql(call) + ), + condition = "CAST_INVALID_INPUT", + parameters = Map( + "expression" -> toSQLValue("A"), + "sourceType" -> toSQLType("STRING"), + "targetType" -> toSQLType("INT")), + context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1)) + } + } + + test("required 
parameters after optional") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundInvalidSum) + val e = intercept[SparkException] { + sql("CALL cat.ns.sum(in2 => 1)") + } + assert(e.getMessage.contains("required arguments should come before optional arguments")) + } + + test("INOUT parameters are not supported") { + catalog.createProcedure(Identifier.of(Array("ns"), "procedure"), UnboundInoutProcedure) + val e = intercept[SparkException] { + sql("CALL cat.ns.procedure(1)") + } + assert(e.getMessage.contains(" Unsupported parameter mode: INOUT")) + } + + test("OUT parameters are not supported") { + catalog.createProcedure(Identifier.of(Array("ns"), "procedure"), UnboundOutProcedure) + val e = intercept[SparkException] { + sql("CALL cat.ns.procedure(1)") + } + assert(e.getMessage.contains("Unsupported parameter mode: OUT")) + } + + test("EXPLAIN") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundNonExecutableSum) + val explain1 = sql("EXPLAIN CALL cat.ns.sum(5, 5)").head().get(0) + assert(explain1.toString.contains("cat.ns.sum(5, 5)")) + val explain2 = sql("EXPLAIN EXTENDED CALL cat.ns.sum(10, 10)").head().get(0) + assert(explain2.toString.contains("cat.ns.sum(10, 10)")) + } + + test("void procedure") { + catalog.createProcedure(Identifier.of(Array("ns"), "proc"), UnboundVoidProcedure) + checkAnswer(sql("CALL cat.ns.proc('A', 'B')"), Nil) + } + + test("multi-result procedure") { + catalog.createProcedure(Identifier.of(Array("ns"), "proc"), UnboundMultiResultProcedure) + checkAnswer(sql("CALL cat.ns.proc()"), Row("last") :: Nil) + } + + test("invalid input to struct procedure") { + catalog.createProcedure(Identifier.of(Array("ns"), "proc"), UnboundStructProcedure) + val actualType = + StructType(Seq( + StructField("X", DataTypes.DateType, nullable = false), + StructField("Y", DataTypes.IntegerType, nullable = false))) + val expectedType = StructProcedure.parameters.head.dataType + val call = "CALL cat.ns.proc(named_struct('X', DATE '2011-11-11', 'Y', 2), 'VALUE')" + checkError( + exception = intercept[AnalysisException](sql(call)), + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "CALL", + "paramIndex" -> "first", + "inputSql" -> "\"named_struct(X, DATE '2011-11-11', Y, 2)\"", + "inputType" -> toSQLType(actualType), + "requiredType" -> toSQLType(expectedType)), + context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1)) + } + + test("save execution summary") { + withTable("summary") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val result = sql("CALL cat.ns.sum(1, 2)") + result.write.saveAsTable("summary") + checkAnswer(spark.table("summary"), Row(3) :: Nil) + } + } + + object UnboundVoidProcedure extends UnboundProcedure { + override def name: String = "void" + override def description: String = "void procedure" + override def bind(inputType: StructType): BoundProcedure = VoidProcedure + } + + object VoidProcedure extends BoundProcedure { + override def name: String = "void" + + override def description: String = "void procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.StringType).build(), + ProcedureParameter.in("in2", DataTypes.StringType).build() + ) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + Collections.emptyIterator + } + } + + object UnboundMultiResultProcedure extends UnboundProcedure { + override def name: 
String = "multi" + override def description: String = "multi-result procedure" + override def bind(inputType: StructType): BoundProcedure = MultiResultProcedure + } + + object MultiResultProcedure extends BoundProcedure { + override def name: String = "multi" + + override def description: String = "multi-result procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array() + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val scans = java.util.Arrays.asList[Scan]( + Result( + new StructType().add("out", DataTypes.IntegerType), + Array(InternalRow(1))), + Result( + new StructType().add("out", DataTypes.StringType), + Array(InternalRow(UTF8String.fromString("last")))) + ) + scans.iterator() + } + } + + object UnboundNonExecutableSum extends UnboundProcedure { + override def name: String = "sum" + override def description: String = "sum integers" + override def bind(inputType: StructType): BoundProcedure = Sum + } + + object NonExecutableSum extends BoundProcedure { + override def name: String = "sum" + + override def description: String = "sum integers" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).build() + ) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundSum extends UnboundProcedure { + override def name: String = "sum" + override def description: String = "sum integers" + override def bind(inputType: StructType): BoundProcedure = Sum + } + + object Sum extends BoundProcedure { + override def name: String = "sum" + + override def description: String = "sum integers" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val in1 = input.getInt(0) + val in2 = input.getInt(1) + val result = Result(outputType, Array(InternalRow(in1 + in2))) + Collections.singleton[Scan](result).iterator() + } + } + + object UnboundLongSum extends UnboundProcedure { + override def name: String = "long_sum" + override def description: String = "sum longs" + override def bind(inputType: StructType): BoundProcedure = LongSum + } + + object LongSum extends BoundProcedure { + override def name: String = "long_sum" + + override def description: String = "sum longs" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.LongType).build(), + ProcedureParameter.in("in2", DataTypes.LongType).build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.LongType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val in1 = input.getLong(0) + val in2 = input.getLong(1) + val result = Result(outputType, Array(InternalRow(in1 + in2))) + Collections.singleton[Scan](result).iterator() + } + } + + object UnboundInvalidSum extends UnboundProcedure { + override def name: String = "invalid" + override def description: String = "sum integers" + override def bind(inputType: StructType): 
BoundProcedure = InvalidSum + } + + object InvalidSum extends BoundProcedure { + override def name: String = "invalid" + + override def description: String = "sum integers" + + override def isDeterministic: Boolean = false + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).defaultValue("1").build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundInvalidDefaultProcedure extends UnboundProcedure { + override def name: String = "sum" + override def description: String = "invalid default value procedure" + override def bind(inputType: StructType): BoundProcedure = InvalidDefaultProcedure + } + + object InvalidDefaultProcedure extends BoundProcedure { + override def name: String = "sum" + + override def description: String = "invalid default value procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).defaultValue("10").build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).defaultValue("'B'").build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundComplexProcedure extends UnboundProcedure { + override def name: String = "complex" + override def description: String = "complex procedure" + override def bind(inputType: StructType): BoundProcedure = ComplexProcedure + } + + object ComplexProcedure extends BoundProcedure { + override def name: String = "complex" + + override def description: String = "complex procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.StringType).defaultValue("'A'").build(), + ProcedureParameter.in("in2", DataTypes.StringType).defaultValue("'B'").build(), + ProcedureParameter.in("in3", DataTypes.IntegerType).defaultValue("1 + 1 - 1").build() + ) + + def outputType: StructType = new StructType() + .add("out1", DataTypes.IntegerType) + .add("out2", DataTypes.StringType) + .add("out3", DataTypes.StringType) + + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val in1 = input.getString(0) + val in2 = input.getString(1) + val in3 = input.getInt(2) + + val rows = (1 to in3).map { index => + val v1 = UTF8String.fromString(s"$in1$index") + val v2 = UTF8String.fromString(s"$in2$index") + InternalRow(index, v1, v2) + }.toArray + + val result = Result(outputType, rows) + Collections.singleton[Scan](result).iterator() + } + } + + object UnboundStructProcedure extends UnboundProcedure { + override def name: String = "struct_input" + override def description: String = "struct procedure" + override def bind(inputType: StructType): BoundProcedure = StructProcedure + } + + object StructProcedure extends BoundProcedure { + override def name: String = "struct_input" + + override def description: String = "struct procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter + .in( + "in1", + StructType(Seq( + StructField("nested1", DataTypes.IntegerType), + 
StructField("nested2", DataTypes.StringType)))) + .build(), + ProcedureParameter.in("in2", DataTypes.StringType).build() + ) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + Collections.emptyIterator + } + } + + object UnboundInoutProcedure extends UnboundProcedure { + override def name: String = "procedure" + override def description: String = "inout procedure" + override def bind(inputType: StructType): BoundProcedure = InoutProcedure + } + + object InoutProcedure extends BoundProcedure { + override def name: String = "procedure" + + override def description: String = "inout procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + CustomParameterImpl(INOUT, "in1", DataTypes.IntegerType) + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundOutProcedure extends UnboundProcedure { + override def name: String = "procedure" + override def description: String = "out procedure" + override def bind(inputType: StructType): BoundProcedure = OutProcedure + } + + object OutProcedure extends BoundProcedure { + override def name: String = "procedure" + + override def description: String = "out procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + CustomParameterImpl(IN, "in1", DataTypes.IntegerType), + CustomParameterImpl(OUT, "out1", DataTypes.IntegerType) + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + case class Result(readSchema: StructType, rows: Array[InternalRow]) extends LocalScan + + case class CustomParameterImpl( + mode: Mode, + name: String, + dataType: DataType) extends ProcedureParameter { + override def defaultValueExpression: String = null + override def comment: String = null + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala index ff944dbb805cb..2254abef3fcb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala @@ -22,8 +22,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.catalog.CatalogTableType -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, DelegatingCatalogExtension, Identifier, Table, TableCatalog, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, DelegatingCatalogExtension, Identifier, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.types.StructType @@ -53,14 +52,10 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating if (tables.containsKey(ident)) { tables.get(ident) } else { - // Table was created through the built-in catalog - super.loadTable(ident) match { - case v1Table: V1Table if v1Table.v1Table.tableType == CatalogTableType.VIEW => v1Table - case t => - val table = newTable(t.name(), t.schema(), t.partitioning(), t.properties()) - addTable(ident, table) - table - } + // 
Table was created through the built-in catalog via v1 command, this is OK as the + // `loadTable` should always be invoked, and we set the `tableCreated` to pass validation. + tableCreated.set(true) + super.loadTable(ident) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala index c2ae5f40cfaf6..f659ca6329e2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala @@ -579,7 +579,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { exception = intercept[SparkRuntimeException] { sql(s"UPDATE $tableNameAsString SET s = named_struct('n_i', null, 'n_l', -1L) WHERE pk = 1") }, - errorClass = "NOT_NULL_ASSERT_VIOLATION", + condition = "NOT_NULL_ASSERT_VIOLATION", sqlState = "42000", parameters = Map("walkedTypePath" -> "\ns\nn_i\n")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala index b4a768a75989a..5091c72ef96ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala @@ -155,7 +155,7 @@ class V2CommandsCaseSensitivitySuite Seq(QualifiedColType( Some(UnresolvedFieldName(field.init.toImmutableArraySeq)), field.last, LongType, true, None, None, None))), - expectedErrorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + expectedErrorCondition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", expectedMessageParameters = Map( "objectName" -> s"`${field.head}`", "proposal" -> "`id`, `data`, `point`") @@ -177,9 +177,9 @@ class V2CommandsCaseSensitivitySuite None))) Seq(true, false).foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = alter, - expectedErrorClass = "FIELD_NOT_FOUND", + expectedErrorCondition = "FIELD_NOT_FOUND", expectedMessageParameters = Map("fieldName" -> "`f`", "fields" -> "id, data, point") ) } @@ -208,9 +208,9 @@ class V2CommandsCaseSensitivitySuite None))) Seq(true, false).foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = alter, - expectedErrorClass = "FIELD_NOT_FOUND", + expectedErrorCondition = "FIELD_NOT_FOUND", expectedMessageParameters = Map("fieldName" -> "`y`", "fields" -> "id, data, point, x") ) } @@ -231,9 +231,9 @@ class V2CommandsCaseSensitivitySuite None))) Seq(true, false).foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = alter, - expectedErrorClass = "FIELD_NOT_FOUND", + expectedErrorCondition = "FIELD_NOT_FOUND", expectedMessageParameters = Map("fieldName" -> "`z`", "fields" -> "x, y") ) } @@ -262,9 +262,9 @@ class V2CommandsCaseSensitivitySuite None))) Seq(true, false).foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( inputPlan = alter, - expectedErrorClass = "FIELD_NOT_FOUND", + expectedErrorCondition = "FIELD_NOT_FOUND", expectedMessageParameters = 
Map("fieldName" -> "`zz`", "fields" -> "x, y, z") ) } @@ -272,7 +272,7 @@ class V2CommandsCaseSensitivitySuite } test("SPARK-36372: Adding duplicate columns should not be allowed") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( AddColumns( table, Seq(QualifiedColType( @@ -297,7 +297,7 @@ class V2CommandsCaseSensitivitySuite } test("SPARK-36381: Check column name exist case sensitive and insensitive when add column") { - alterTableErrorClass( + alterTableErrorCondition( AddColumns( table, Seq(QualifiedColType( @@ -317,7 +317,7 @@ class V2CommandsCaseSensitivitySuite } test("SPARK-36381: Check column name exist case sensitive and insensitive when rename column") { - alterTableErrorClass( + alterTableErrorCondition( RenameColumn(table, UnresolvedFieldName(Array("id").toImmutableArraySeq), "DATA"), "FIELD_ALREADY_EXISTS", Map( @@ -338,7 +338,7 @@ class V2CommandsCaseSensitivitySuite } else { alterTableTest( alter = alter, - expectedErrorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + expectedErrorCondition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", expectedMessageParameters = Map( "objectName" -> s"${toSQLId(ref.toImmutableArraySeq)}", "proposal" -> "`id`, `data`, `point`" @@ -353,7 +353,7 @@ class V2CommandsCaseSensitivitySuite Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref => alterTableTest( alter = RenameColumn(table, UnresolvedFieldName(ref.toImmutableArraySeq), "newName"), - expectedErrorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + expectedErrorCondition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", expectedMessageParameters = Map( "objectName" -> s"${toSQLId(ref.toImmutableArraySeq)}", "proposal" -> "`id`, `data`, `point`") @@ -366,7 +366,7 @@ class V2CommandsCaseSensitivitySuite alterTableTest( AlterColumn(table, UnresolvedFieldName(ref.toImmutableArraySeq), None, Some(true), None, None, None), - expectedErrorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + expectedErrorCondition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", expectedMessageParameters = Map( "objectName" -> s"${toSQLId(ref.toImmutableArraySeq)}", "proposal" -> "`id`, `data`, `point`") @@ -379,7 +379,7 @@ class V2CommandsCaseSensitivitySuite alterTableTest( AlterColumn(table, UnresolvedFieldName(ref.toImmutableArraySeq), Some(StringType), None, None, None, None), - expectedErrorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + expectedErrorCondition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", expectedMessageParameters = Map( "objectName" -> s"${toSQLId(ref.toImmutableArraySeq)}", "proposal" -> "`id`, `data`, `point`") @@ -392,7 +392,7 @@ class V2CommandsCaseSensitivitySuite alterTableTest( AlterColumn(table, UnresolvedFieldName(ref.toImmutableArraySeq), None, None, Some("comment"), None, None), - expectedErrorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + expectedErrorCondition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", expectedMessageParameters = Map( "objectName" -> s"${toSQLId(ref.toImmutableArraySeq)}", "proposal" -> "`id`, `data`, `point`") @@ -401,7 +401,7 @@ class V2CommandsCaseSensitivitySuite } test("SPARK-36449: Replacing columns with duplicate name should not be allowed") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( ReplaceColumns( table, Seq(QualifiedColType(None, "f", LongType, true, None, None, None), @@ -413,15 +413,15 @@ class V2CommandsCaseSensitivitySuite private def alterTableTest( alter: => AlterTableCommand, - expectedErrorClass: String, + expectedErrorCondition: String, expectedMessageParameters: Map[String, String], expectErrorOnCaseSensitive: Boolean 
= true): Unit = { Seq(true, false).foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { val expectError = if (expectErrorOnCaseSensitive) caseSensitive else !caseSensitive if (expectError) { - assertAnalysisErrorClass( - alter, expectedErrorClass, expectedMessageParameters, caseSensitive = caseSensitive) + assertAnalysisErrorCondition( + alter, expectedErrorCondition, expectedMessageParameters, caseSensitive = caseSensitive) } else { assertAnalysisSuccess(alter, caseSensitive) } @@ -429,17 +429,17 @@ class V2CommandsCaseSensitivitySuite } } - private def alterTableErrorClass( + private def alterTableErrorCondition( alter: => AlterTableCommand, - errorClass: String, + condition: String, messageParameters: Map[String, String], expectErrorOnCaseSensitive: Boolean = true): Unit = { Seq(true, false).foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { val expectError = if (expectErrorOnCaseSensitive) caseSensitive else !caseSensitive if (expectError) { - assertAnalysisErrorClass( - alter, errorClass, messageParameters, caseSensitive = caseSensitive) + assertAnalysisErrorCondition( + alter, condition, messageParameters, caseSensitive = caseSensitive) } else { assertAnalysisSuccess(alter, caseSensitive) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsDSv2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsDSv2Suite.scala index 39809c785af92..48d4e45ebf354 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsDSv2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsDSv2Suite.scala @@ -51,7 +51,7 @@ class QueryCompilationErrorsDSv2Suite checkAnswer(spark.table(tbl), spark.emptyDataFrame) checkError( exception = e, - errorClass = "UNSUPPORTED_FEATURE.INSERT_PARTITION_SPEC_IF_NOT_EXISTS", + condition = "UNSUPPORTED_FEATURE.INSERT_PARTITION_SPEC_IF_NOT_EXISTS", parameters = Map("tableName" -> "`testcat`.`ns1`.`ns2`.`tbl`"), sqlState = "0A000") } @@ -70,7 +70,7 @@ class QueryCompilationErrorsDSv2Suite verifyTable(t1, spark.emptyDataFrame) checkError( exception = e, - errorClass = "NON_PARTITION_COLUMN", + condition = "NON_PARTITION_COLUMN", parameters = Map("columnName" -> "`id`")) } } @@ -87,7 +87,7 @@ class QueryCompilationErrorsDSv2Suite verifyTable(t1, spark.emptyDataFrame) checkError( exception = e, - errorClass = "NON_PARTITION_COLUMN", + condition = "NON_PARTITION_COLUMN", parameters = Map("columnName" -> "`data`")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 371a615828de3..832e1873af6a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.api.java.{UDF1, UDF2, UDF23Test} import org.apache.spark.sql.catalyst.expressions.{Coalesce, Literal, UnsafeRow} import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.execution.datasources.SaveIntoDataSourceCommand import org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.functions._ @@ -51,7 +52,7 @@ class 
QueryCompilationErrorsSuite } checkError( exception = e1, - errorClass = "CANNOT_UP_CAST_DATATYPE", + condition = "CANNOT_UP_CAST_DATATYPE", parameters = Map("expression" -> "b", "sourceType" -> "\"BIGINT\"", "targetType" -> "\"INT\"", "details" -> ( s""" @@ -68,7 +69,7 @@ class QueryCompilationErrorsSuite } checkError( exception = e2, - errorClass = "CANNOT_UP_CAST_DATATYPE", + condition = "CANNOT_UP_CAST_DATATYPE", parameters = Map("expression" -> "b.`b`", "sourceType" -> "\"DECIMAL(38,18)\"", "targetType" -> "\"BIGINT\"", "details" -> ( @@ -94,7 +95,7 @@ class QueryCompilationErrorsSuite } checkError( exception = e, - errorClass = "UNSUPPORTED_GROUPING_EXPRESSION", + condition = "UNSUPPORTED_GROUPING_EXPRESSION", parameters = Map[String, String]()) } } @@ -112,7 +113,7 @@ class QueryCompilationErrorsSuite } checkError( exception = e, - errorClass = "UNSUPPORTED_GROUPING_EXPRESSION", + condition = "UNSUPPORTED_GROUPING_EXPRESSION", parameters = Map[String, String]()) } } @@ -123,7 +124,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { sql("select format_string('%0$s', 'Hello')") }, - errorClass = "INVALID_PARAMETER_VALUE.ZERO_INDEX", + condition = "INVALID_PARAMETER_VALUE.ZERO_INDEX", parameters = Map( "parameter" -> "`strfmt`", "functionName" -> "`format_string`"), @@ -156,7 +157,7 @@ class QueryCompilationErrorsSuite checkError( exception = e, - errorClass = "INVALID_PANDAS_UDF_PLACEMENT", + condition = "INVALID_PANDAS_UDF_PLACEMENT", parameters = Map("functionList" -> "`pandas_udf_1`, `pandas_udf_2`")) } @@ -183,7 +184,7 @@ class QueryCompilationErrorsSuite checkError( exception = e, - errorClass = "UNSUPPORTED_FEATURE.PYTHON_UDF_IN_ON_CLAUSE", + condition = "UNSUPPORTED_FEATURE.PYTHON_UDF_IN_ON_CLAUSE", parameters = Map("joinType" -> "LEFT OUTER"), sqlState = Some("0A000")) } @@ -205,7 +206,7 @@ class QueryCompilationErrorsSuite checkError( exception = e, - errorClass = "UNSUPPORTED_FEATURE.PANDAS_UDAF_IN_PIVOT", + condition = "UNSUPPORTED_FEATURE.PANDAS_UDAF_IN_PIVOT", parameters = Map[String, String](), sqlState = "0A000") } @@ -224,7 +225,7 @@ class QueryCompilationErrorsSuite ) checkError( exception = e, - errorClass = "NO_HANDLER_FOR_UDAF", + condition = "NO_HANDLER_FOR_UDAF", parameters = Map("functionName" -> "org.apache.spark.sql.errors.MyCastToString"), context = ExpectedContext( fragment = "myCast(123)", start = 7, stop = 17)) @@ -234,7 +235,7 @@ class QueryCompilationErrorsSuite test("UNTYPED_SCALA_UDF: use untyped Scala UDF should fail by default") { checkError( exception = intercept[AnalysisException](udf((x: Int) => x, IntegerType)), - errorClass = "UNTYPED_SCALA_UDF", + condition = "UNTYPED_SCALA_UDF", parameters = Map[String, String]()) } @@ -248,7 +249,7 @@ class QueryCompilationErrorsSuite ) checkError( exception = e, - errorClass = "NO_UDF_INTERFACE", + condition = "NO_UDF_INTERFACE", parameters = Map("className" -> className)) } @@ -262,7 +263,7 @@ class QueryCompilationErrorsSuite ) checkError( exception = e, - errorClass = "MULTI_UDF_INTERFACE_ERROR", + condition = "MULTI_UDF_INTERFACE_ERROR", parameters = Map("className" -> className)) } @@ -276,7 +277,7 @@ class QueryCompilationErrorsSuite ) checkError( exception = e, - errorClass = "UNSUPPORTED_FEATURE.TOO_MANY_TYPE_ARGUMENTS_FOR_UDF_CLASS", + condition = "UNSUPPORTED_FEATURE.TOO_MANY_TYPE_ARGUMENTS_FOR_UDF_CLASS", parameters = Map("num" -> "24"), sqlState = "0A000") } @@ -287,7 +288,7 @@ class QueryCompilationErrorsSuite } checkError( exception = groupingColMismatchEx, - 
errorClass = "GROUPING_COLUMN_MISMATCH", + condition = "GROUPING_COLUMN_MISMATCH", parameters = Map("grouping" -> "earnings.*", "groupingColumns" -> "course.*,year.*"), sqlState = Some("42803"), matchPVals = true) @@ -299,7 +300,7 @@ class QueryCompilationErrorsSuite } checkError( exception = groupingIdColMismatchEx, - errorClass = "GROUPING_ID_COLUMN_MISMATCH", + condition = "GROUPING_ID_COLUMN_MISMATCH", parameters = Map("groupingIdColumn" -> "earnings.*", "groupByColumns" -> "course.*,year.*"), sqlState = Some("42803"), @@ -322,14 +323,14 @@ class QueryCompilationErrorsSuite withSQLConf(SQLConf.LEGACY_INTEGER_GROUPING_ID.key -> "true") { checkError( exception = intercept[AnalysisException] { testGroupingIDs(33) }, - errorClass = "GROUPING_SIZE_LIMIT_EXCEEDED", + condition = "GROUPING_SIZE_LIMIT_EXCEEDED", parameters = Map("maxSize" -> "32")) } withSQLConf(SQLConf.LEGACY_INTEGER_GROUPING_ID.key -> "false") { checkError( exception = intercept[AnalysisException] { testGroupingIDs(65) }, - errorClass = "GROUPING_SIZE_LIMIT_EXCEEDED", + condition = "GROUPING_SIZE_LIMIT_EXCEEDED", parameters = Map("maxSize" -> "64")) } } @@ -354,7 +355,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { sql(s"DESC TABLE $tempViewName PARTITION (c='Us', d=1)") }, - errorClass = "FORBIDDEN_OPERATION", + condition = "FORBIDDEN_OPERATION", parameters = Map("statement" -> "DESC PARTITION", "objectType" -> "TEMPORARY VIEW", "objectName" -> s"`$tempViewName`")) } @@ -380,7 +381,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { sql(s"DESC TABLE $viewName PARTITION (c='Us', d=1)") }, - errorClass = "FORBIDDEN_OPERATION", + condition = "FORBIDDEN_OPERATION", parameters = Map("statement" -> "DESC PARTITION", "objectType" -> "VIEW", "objectName" -> s"`$viewName`")) } @@ -394,7 +395,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { sql("select date_add('1982-08-15', 'x')").collect() }, - errorClass = "SECOND_FUNCTION_ARGUMENT_NOT_INTEGER", + condition = "SECOND_FUNCTION_ARGUMENT_NOT_INTEGER", parameters = Map("functionName" -> "date_add"), sqlState = "22023") } @@ -408,7 +409,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { spark.read.schema(schema).json(spark.emptyDataset[String]) }, - errorClass = "INVALID_JSON_SCHEMA_MAP_TYPE", + condition = "INVALID_JSON_SCHEMA_MAP_TYPE", parameters = Map("jsonSchema" -> "\"STRUCT NOT NULL>\"") ) } @@ -418,7 +419,7 @@ class QueryCompilationErrorsSuite val query = "select m[a] from (select map('a', 'b') as m, 'aa' as aa)" checkError( exception = intercept[AnalysisException] {sql(query)}, - errorClass = "UNRESOLVED_MAP_KEY.WITH_SUGGESTION", + condition = "UNRESOLVED_MAP_KEY.WITH_SUGGESTION", sqlState = None, parameters = Map("objectName" -> "`a`", "proposal" -> "`aa`, `m`"), @@ -433,7 +434,7 @@ class QueryCompilationErrorsSuite val query = "select m[a] from (select map('a', 'b') as m, 'aa' as `a.a`)" checkError( exception = intercept[AnalysisException] {sql(query)}, - errorClass = "UNRESOLVED_MAP_KEY.WITH_SUGGESTION", + condition = "UNRESOLVED_MAP_KEY.WITH_SUGGESTION", sqlState = None, parameters = Map( "objectName" -> "`a`", @@ -468,7 +469,7 @@ class QueryCompilationErrorsSuite |order by struct.a, struct.b |""".stripMargin) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map( "objectName" -> "`struct`.`a`", @@ -489,7 +490,7 @@ class QueryCompilationErrorsSuite val query = 
"SELECT v.i from (SELECT i FROM v)" checkError( exception = intercept[AnalysisException](sql(query)), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map( "objectName" -> "`v`.`i`", @@ -522,7 +523,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "AMBIGUOUS_ALIAS_IN_NESTED_CTE", + condition = "AMBIGUOUS_ALIAS_IN_NESTED_CTE", parameters = Map( "name" -> "`t`", "config" -> toSQLConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY.key), @@ -542,7 +543,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "AMBIGUOUS_COLUMN_OR_FIELD", + condition = "AMBIGUOUS_COLUMN_OR_FIELD", parameters = Map("name" -> "`c`.`X`", "n" -> "2"), context = ExpectedContext( fragment = query, start = 0, stop = 52)) @@ -566,7 +567,7 @@ class QueryCompilationErrorsSuite struct(lit("java"), lit("Dummies")))). agg(sum($"earnings")).collect() }, - errorClass = "PIVOT_VALUE_DATA_TYPE_MISMATCH", + condition = "PIVOT_VALUE_DATA_TYPE_MISMATCH", parameters = Map("value" -> "struct(col1, dotnet, col2, Experts)", "valueType" -> "struct", "pivotType" -> "int")) @@ -581,7 +582,7 @@ class QueryCompilationErrorsSuite } checkError( exception = e, - errorClass = "INVALID_FIELD_NAME", + condition = "INVALID_FIELD_NAME", parameters = Map("fieldName" -> "`m`.`n`", "path" -> "`m`"), context = ExpectedContext( fragment = "m.n int", start = 27, stop = 33)) @@ -603,7 +604,7 @@ class QueryCompilationErrorsSuite pivot(df("year"), Seq($"earnings")). agg(sum($"earnings")).collect() }, - errorClass = "NON_LITERAL_PIVOT_VALUES", + condition = "NON_LITERAL_PIVOT_VALUES", parameters = Map("expression" -> "\"earnings\"")) } @@ -613,7 +614,7 @@ class QueryCompilationErrorsSuite } checkError( exception = e, - errorClass = "UNSUPPORTED_DESERIALIZER.DATA_TYPE_MISMATCH", + condition = "UNSUPPORTED_DESERIALIZER.DATA_TYPE_MISMATCH", parameters = Map("desiredType" -> "\"ARRAY\"", "dataType" -> "\"INT\"")) } @@ -626,7 +627,7 @@ class QueryCompilationErrorsSuite } checkError( exception = e1, - errorClass = "UNSUPPORTED_DESERIALIZER.FIELD_NUMBER_MISMATCH", + condition = "UNSUPPORTED_DESERIALIZER.FIELD_NUMBER_MISMATCH", parameters = Map( "schema" -> "\"STRUCT\"", "ordinal" -> "3")) @@ -636,7 +637,7 @@ class QueryCompilationErrorsSuite } checkError( exception = e2, - errorClass = "UNSUPPORTED_DESERIALIZER.FIELD_NUMBER_MISMATCH", + condition = "UNSUPPORTED_DESERIALIZER.FIELD_NUMBER_MISMATCH", parameters = Map("schema" -> "\"STRUCT\"", "ordinal" -> "1")) } @@ -649,7 +650,7 @@ class QueryCompilationErrorsSuite checkError( exception = e, - errorClass = "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS", + condition = "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS", parameters = Map("expression" -> "\"(explode(array(1, 2, 3)) + 1)\"")) } @@ -660,7 +661,7 @@ class QueryCompilationErrorsSuite checkError( exception = e, - errorClass = "UNSUPPORTED_GENERATOR.OUTSIDE_SELECT", + condition = "UNSUPPORTED_GENERATOR.OUTSIDE_SELECT", parameters = Map("plan" -> "'Sort [explode(array(1, 2, 3)) ASC NULLS FIRST], true")) } @@ -675,7 +676,7 @@ class QueryCompilationErrorsSuite checkError( exception = e, - errorClass = "UNSUPPORTED_GENERATOR.NOT_GENERATOR", + condition = "UNSUPPORTED_GENERATOR.NOT_GENERATOR", sqlState = None, parameters = Map( "functionName" -> "`array_contains`", @@ -690,7 +691,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { 
Seq("""{"a":1}""").toDF("a").select(from_json($"a", IntegerType)).collect() }, - errorClass = "DATATYPE_MISMATCH.INVALID_JSON_SCHEMA", + condition = "DATATYPE_MISMATCH.INVALID_JSON_SCHEMA", parameters = Map("schema" -> "\"INT\"", "sqlExpr" -> "\"from_json(a)\""), context = ExpectedContext(fragment = "from_json", callSitePattern = getCurrentClassCallSitePattern)) @@ -701,7 +702,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { sql("SELECT CAST(1)") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> "`cast`", "expectedNum" -> "0", @@ -717,7 +718,7 @@ class QueryCompilationErrorsSuite exception = intercept[ParseException] { sql("CREATE TEMPORARY VIEW db_name.schema_name.view_name AS SELECT '1' as test_column") }, - errorClass = "IDENTIFIER_TOO_MANY_NAME_PARTS", + condition = "IDENTIFIER_TOO_MANY_NAME_PARTS", sqlState = "42601", parameters = Map("identifier" -> "`db_name`.`schema_name`.`view_name`") ) @@ -738,7 +739,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $tableName RENAME TO db_name.schema_name.new_table_name") }, - errorClass = "IDENTIFIER_TOO_MANY_NAME_PARTS", + condition = "IDENTIFIER_TOO_MANY_NAME_PARTS", sqlState = "42601", parameters = Map("identifier" -> "`db_name`.`schema_name`.`new_table_name`") ) @@ -762,7 +763,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { df.select($"name.firstname") }, - errorClass = "AMBIGUOUS_REFERENCE_TO_FIELDS", + condition = "AMBIGUOUS_REFERENCE_TO_FIELDS", sqlState = "42000", parameters = Map("field" -> "`firstname`", "count" -> "2"), context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern) @@ -776,7 +777,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { df.select($"firstname.test_field") }, - errorClass = "INVALID_EXTRACT_BASE_FIELD_TYPE", + condition = "INVALID_EXTRACT_BASE_FIELD_TYPE", sqlState = "42000", parameters = Map("base" -> "\"firstname\"", "other" -> "\"STRING\""), context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern) @@ -802,7 +803,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { df.select(struct($"name"(struct("test")))) }, - errorClass = "INVALID_EXTRACT_FIELD_TYPE", + condition = "INVALID_EXTRACT_FIELD_TYPE", sqlState = "42000", parameters = Map("extraction" -> "\"struct(test)\"")) @@ -810,7 +811,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { df.select($"name"(array("test"))) }, - errorClass = "INVALID_EXTRACT_FIELD_TYPE", + condition = "INVALID_EXTRACT_FIELD_TYPE", sqlState = "42000", parameters = Map("extraction" -> "\"array(test)\"")) } @@ -831,7 +832,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { sql(query) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "proposal" -> "`c1`, `v1`.`c2`, `v2`.`c2`", "objectName" -> "`b`"), @@ -849,7 +850,7 @@ class QueryCompilationErrorsSuite exception = intercept[SparkIllegalArgumentException] { coalesce.dataType }, - errorClass = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.NO_INPUTS", + condition = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.NO_INPUTS", parameters = Map("expression" -> "\"coalesce()\"")) } @@ -861,7 +862,7 @@ class QueryCompilationErrorsSuite exception = intercept[SparkIllegalArgumentException] { 
coalesce.dataType }, - errorClass = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.MISMATCHED_TYPES", + condition = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.MISMATCHED_TYPES", parameters = Map( "expression" -> "\"coalesce(1, a, a)\"", "inputTypes" -> "[\"INT\", \"STRING\", \"STRING\"]")) @@ -872,7 +873,7 @@ class QueryCompilationErrorsSuite exception = intercept[SparkUnsupportedOperationException] { new UnsafeRow(1).update(0, 1) }, - errorClass = "UNSUPPORTED_CALL.WITHOUT_SUGGESTION", + condition = "UNSUPPORTED_CALL.WITHOUT_SUGGESTION", parameters = Map( "methodName" -> "update", "className" -> "org.apache.spark.sql.catalyst.expressions.UnsafeRow")) @@ -891,7 +892,7 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { converter.convertField(StructField("test", dummyDataType)) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "Cannot convert Spark data type \"DUMMY\" to any Parquet type.") ) } @@ -919,13 +920,32 @@ class QueryCompilationErrorsSuite exception = intercept[AnalysisException] { sql(test.query) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`dummy`", "proposal" -> "`a`, `b`"), context = ExpectedContext(fragment = "dummy", start = test.pos, stop = test.pos + 4) ) }) } } + + test("Catch and log errors when failing to write to external data source") { + val password = "MyPassWord" + val token = "MyToken" + val value = "value" + val options = Map("password" -> password, "token" -> token, "key" -> value) + val query = spark.range(10).logicalPlan + val cmd = SaveIntoDataSourceCommand(query, null, options, SaveMode.Overwrite) + + checkError( + exception = intercept[AnalysisException] { + cmd.run(spark) + }, + condition = "DATA_SOURCE_EXTERNAL_ERROR", + sqlState = "KD010", + parameters = Map.empty + ) + } + } class MyCastToString extends SparkUserDefinedFunction( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala index 83495e0670240..ec92e0b700e31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala @@ -45,11 +45,10 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArithmeticException] { sql("select CAST(TIMESTAMP '9999-12-31T12:13:14.56789Z' AS INT)").collect() }, - errorClass = "CAST_OVERFLOW", + condition = "CAST_OVERFLOW", parameters = Map("value" -> "TIMESTAMP '9999-12-31 04:13:14.56789'", "sourceType" -> "\"TIMESTAMP\"", - "targetType" -> "\"INT\"", - "ansiConfig" -> ansiConf), + "targetType" -> "\"INT\""), sqlState = "22003") } @@ -58,7 +57,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArithmeticException] { sql("select 6/0").collect() }, - errorClass = "DIVIDE_BY_ZERO", + condition = "DIVIDE_BY_ZERO", sqlState = "22012", parameters = Map("config" -> ansiConf), context = ExpectedContext(fragment = "6/0", start = 7, stop = 9)) @@ -67,7 +66,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArithmeticException] { OneRowRelation().select(lit(5) / lit(0)).collect() }, - errorClass = "DIVIDE_BY_ZERO", + condition = "DIVIDE_BY_ZERO", sqlState = "22012", parameters = Map("config" -> ansiConf), context = ExpectedContext(fragment = "div", 
callSitePattern = getCurrentClassCallSitePattern)) @@ -76,7 +75,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArithmeticException] { OneRowRelation().select(lit(5).divide(lit(0))).collect() }, - errorClass = "DIVIDE_BY_ZERO", + condition = "DIVIDE_BY_ZERO", sqlState = "22012", parameters = Map("config" -> ansiConf), context = ExpectedContext( @@ -89,7 +88,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArithmeticException] { sql("select interval 1 day / 0").collect() }, - errorClass = "INTERVAL_DIVIDED_BY_ZERO", + condition = "INTERVAL_DIVIDED_BY_ZERO", sqlState = "22012", parameters = Map.empty[String, String], context = ExpectedContext(fragment = "interval 1 day / 0", start = 7, stop = 24)) @@ -100,7 +99,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkDateTimeException] { sql("select make_timestamp(2012, 11, 30, 9, 19, 60.66666666)").collect() }, - errorClass = "INVALID_FRACTION_OF_SECOND", + condition = "INVALID_FRACTION_OF_SECOND", sqlState = "22023", parameters = Map("ansiConfig" -> ansiConf)) } @@ -110,7 +109,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArithmeticException] { sql("select CAST('66666666666666.666' AS DECIMAL(8, 1))").collect() }, - errorClass = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", + condition = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", sqlState = "22003", parameters = Map( "value" -> "66666666666666.666", @@ -126,7 +125,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArithmeticException] { OneRowRelation().select(lit("66666666666666.666").cast("DECIMAL(8, 1)")).collect() }, - errorClass = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", + condition = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", sqlState = "22003", parameters = Map( "value" -> "66666666666666.666", @@ -143,7 +142,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArrayIndexOutOfBoundsException] { sql("select array(1, 2, 3, 4, 5)[8]").collect() }, - errorClass = "INVALID_ARRAY_INDEX", + condition = "INVALID_ARRAY_INDEX", parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), context = ExpectedContext(fragment = "array(1, 2, 3, 4, 5)[8]", start = 7, stop = 29)) @@ -151,7 +150,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArrayIndexOutOfBoundsException] { OneRowRelation().select(lit(Array(1, 2, 3, 4, 5))(8)).collect() }, - errorClass = "INVALID_ARRAY_INDEX", + condition = "INVALID_ARRAY_INDEX", parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), context = ExpectedContext( fragment = "apply", @@ -163,7 +162,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArrayIndexOutOfBoundsException] { sql("select element_at(array(1, 2, 3, 4, 5), 8)").collect() }, - errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", + condition = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), context = ExpectedContext( fragment = "element_at(array(1, 2, 3, 4, 5), 8)", @@ -174,7 +173,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArrayIndexOutOfBoundsException] { OneRowRelation().select(element_at(lit(Array(1, 2, 3, 4, 5)), 8)).collect() }, - errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", + condition = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", parameters 
= Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), context = ExpectedContext(fragment = "element_at", callSitePattern = getCurrentClassCallSitePattern)) @@ -185,7 +184,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkRuntimeException]( sql("select element_at(array(1, 2, 3, 4, 5), 0)").collect() ), - errorClass = "INVALID_INDEX_OF_ZERO", + condition = "INVALID_INDEX_OF_ZERO", parameters = Map.empty, context = ExpectedContext( fragment = "element_at(array(1, 2, 3, 4, 5), 0)", @@ -197,7 +196,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkRuntimeException]( OneRowRelation().select(element_at(lit(Array(1, 2, 3, 4, 5)), 0)).collect() ), - errorClass = "INVALID_INDEX_OF_ZERO", + condition = "INVALID_INDEX_OF_ZERO", parameters = Map.empty, context = ExpectedContext(fragment = "element_at", callSitePattern = getCurrentClassCallSitePattern)) @@ -208,12 +207,11 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkNumberFormatException] { sql("select CAST('111111111111xe23' AS DOUBLE)").collect() }, - errorClass = "CAST_INVALID_INPUT", + condition = "CAST_INVALID_INPUT", parameters = Map( "expression" -> "'111111111111xe23'", "sourceType" -> "\"STRING\"", - "targetType" -> "\"DOUBLE\"", - "ansiConfig" -> ansiConf), + "targetType" -> "\"DOUBLE\""), context = ExpectedContext( fragment = "CAST('111111111111xe23' AS DOUBLE)", start = 7, @@ -223,12 +221,11 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkNumberFormatException] { OneRowRelation().select(lit("111111111111xe23").cast("DOUBLE")).collect() }, - errorClass = "CAST_INVALID_INPUT", + condition = "CAST_INVALID_INPUT", parameters = Map( "expression" -> "'111111111111xe23'", "sourceType" -> "\"STRING\"", - "targetType" -> "\"DOUBLE\"", - "ansiConfig" -> ansiConf), + "targetType" -> "\"DOUBLE\""), context = ExpectedContext( fragment = "cast", callSitePattern = getCurrentClassCallSitePattern)) @@ -239,7 +236,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkDateTimeException] { sql("select to_timestamp('abc', 'yyyy-MM-dd HH:mm:ss')").collect() }, - errorClass = "CANNOT_PARSE_TIMESTAMP", + condition = "CANNOT_PARSE_TIMESTAMP", parameters = Map( "message" -> "Text 'abc' could not be parsed at index 0", "ansiConfig" -> ansiConf) @@ -255,7 +252,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArithmeticException] { sql(s"insert into $tableName values 12345678901234567890D") }, - errorClass = "CAST_OVERFLOW_IN_TABLE_INSERT", + condition = "CAST_OVERFLOW_IN_TABLE_INSERT", parameters = Map( "sourceType" -> "\"DOUBLE\"", "targetType" -> ("\"" + targetType + "\""), @@ -272,11 +269,10 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArithmeticException] { CheckOverflowInTableInsert(caseWhen, "col").eval(null) }, - errorClass = "CAST_OVERFLOW", + condition = "CAST_OVERFLOW", parameters = Map("value" -> "1.2345678901234567E19D", "sourceType" -> "\"DOUBLE\"", - "targetType" -> ("\"TINYINT\""), - "ansiConfig" -> ansiConf) + "targetType" -> ("\"TINYINT\"")) ) } @@ -291,11 +287,10 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArithmeticException] { sql(insertCmd).collect() }, - errorClass = "CAST_OVERFLOW", + condition = "CAST_OVERFLOW", parameters = Map("value" -> "-1.2345678901234567E19D", "sourceType" -> "\"DOUBLE\"", - "targetType" -> "\"TINYINT\"", - 
"ansiConfig" -> ansiConf), + "targetType" -> "\"TINYINT\""), sqlState = "22003") } } @@ -306,7 +301,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest CheckOverflowInTableInsert( Cast(Literal.apply(12345678901234567890D), ByteType), "col").eval(null) }.asInstanceOf[SparkThrowable], - errorClass = "CAST_OVERFLOW_IN_TABLE_INSERT", + condition = "CAST_OVERFLOW_IN_TABLE_INSERT", parameters = Map( "sourceType" -> "\"DOUBLE\"", "targetType" -> ("\"TINYINT\""), @@ -322,7 +317,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest exception = intercept[SparkArithmeticException] { CheckOverflowInTableInsert(proxy, "col").eval(null) }.asInstanceOf[SparkThrowable], - errorClass = "CAST_OVERFLOW_IN_TABLE_INSERT", + condition = "CAST_OVERFLOW_IN_TABLE_INSERT", parameters = Map( "sourceType" -> "\"DOUBLE\"", "targetType" -> ("\"TINYINT\""), @@ -366,7 +361,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest checkError( // If error is user-facing, it will be thrown directly. exception = intercept[SparkArithmeticException](df3.collect()), - errorClass = "DIVIDE_BY_ZERO", + condition = "DIVIDE_BY_ZERO", parameters = Map("config" -> ansiConf), context = ExpectedContext( fragment = "div", @@ -381,7 +376,7 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest val df4 = spark.range(0, 10, 1, 2).select(lit(1) / $"id") checkError( exception = intercept[SparkArithmeticException](df4.collect()), - errorClass = "DIVIDE_BY_ZERO", + condition = "DIVIDE_BY_ZERO", parameters = Map("config" -> ansiConf), context = ExpectedContext( fragment = "div", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index 349b124970e32..00dfd3451d577 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -77,7 +77,7 @@ class QueryExecutionErrorsSuite } checkError( exception, - errorClass = "CONVERSION_INVALID_INPUT", + condition = "CONVERSION_INVALID_INPUT", parameters = Map( "str" -> "'???'", "fmt" -> "'BASE64'", @@ -95,7 +95,7 @@ class QueryExecutionErrorsSuite } checkError( exception, - errorClass = "CONVERSION_INVALID_INPUT", + condition = "CONVERSION_INVALID_INPUT", parameters = Map( "str" -> "'???'", "fmt" -> "'HEX'", @@ -129,7 +129,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkRuntimeException] { df.collect() }, - errorClass = "INVALID_PARAMETER_VALUE.AES_KEY_LENGTH", + condition = "INVALID_PARAMETER_VALUE.AES_KEY_LENGTH", parameters = Map( "parameter" -> "`key`", "functionName" -> "`aes_encrypt`/`aes_decrypt`", @@ -166,7 +166,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkRuntimeException] { df2.selectExpr(s"aes_decrypt(unbase64($colName), binary('$key'), 'ECB')").collect() }, - errorClass = "INVALID_PARAMETER_VALUE.AES_CRYPTO_ERROR", + condition = "INVALID_PARAMETER_VALUE.AES_CRYPTO_ERROR", parameters = Map("parameter" -> "`expr`, `key`", "functionName" -> "`aes_encrypt`/`aes_decrypt`", "detailMessage" -> ("Given final block not properly padded. 
" + @@ -184,7 +184,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkRuntimeException] { df.collect() }, - errorClass = "UNSUPPORTED_FEATURE.AES_MODE", + condition = "UNSUPPORTED_FEATURE.AES_MODE", parameters = Map("mode" -> mode, "padding" -> padding, "functionName" -> "`aes_encrypt`/`aes_decrypt`"), @@ -212,7 +212,7 @@ class QueryExecutionErrorsSuite def checkUnsupportedTypeInLiteral(v: Any, literal: String, dataType: String): Unit = { checkError( exception = intercept[SparkRuntimeException] { spark.expression(lit(v)) }, - errorClass = "UNSUPPORTED_FEATURE.LITERAL_TYPE", + condition = "UNSUPPORTED_FEATURE.LITERAL_TYPE", parameters = Map("value" -> literal, "type" -> dataType), sqlState = "0A000") } @@ -230,7 +230,7 @@ class QueryExecutionErrorsSuite } checkError( exception = e2, - errorClass = "UNSUPPORTED_FEATURE.PIVOT_TYPE", + condition = "UNSUPPORTED_FEATURE.PIVOT_TYPE", parameters = Map("value" -> "[dotnet,Dummies]", "type" -> "unknown"), sqlState = "0A000") @@ -247,7 +247,7 @@ class QueryExecutionErrorsSuite } checkError( exception = e1, - errorClass = "REPEATED_CLAUSE", + condition = "REPEATED_CLAUSE", parameters = Map("clause" -> "PIVOT", "operation" -> "SUBQUERY"), sqlState = "42614") @@ -260,7 +260,7 @@ class QueryExecutionErrorsSuite } checkError( exception = e2, - errorClass = "UNSUPPORTED_FEATURE.PIVOT_AFTER_GROUP_BY", + condition = "UNSUPPORTED_FEATURE.PIVOT_AFTER_GROUP_BY", parameters = Map[String, String](), sqlState = "0A000") } @@ -281,7 +281,7 @@ class QueryExecutionErrorsSuite val option = "\"datetimeRebaseMode\"" checkError( exception = e, - errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.READ_ANCIENT_DATETIME", + condition = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.READ_ANCIENT_DATETIME", parameters = Map("format" -> format, "config" -> config, "option" -> option)) } @@ -298,7 +298,7 @@ class QueryExecutionErrorsSuite val config = "\"" + SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key + "\"" checkError( exception = e.getCause.asInstanceOf[SparkUpgradeException], - errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.WRITE_ANCIENT_DATETIME", + condition = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.WRITE_ANCIENT_DATETIME", parameters = Map("format" -> format, "config" -> config)) } } @@ -314,7 +314,7 @@ class QueryExecutionErrorsSuite assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[SparkUnsupportedOperationException], - errorClass = "UNSUPPORTED_FEATURE.ORC_TYPE_CAST", + condition = "UNSUPPORTED_FEATURE.ORC_TYPE_CAST", parameters = Map("orcType" -> "\"TIMESTAMP\"", "toType" -> "\"TIMESTAMP_NTZ\""), sqlState = "0A000") @@ -336,7 +336,7 @@ class QueryExecutionErrorsSuite assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[SparkUnsupportedOperationException], - errorClass = "UNSUPPORTED_FEATURE.ORC_TYPE_CAST", + condition = "UNSUPPORTED_FEATURE.ORC_TYPE_CAST", parameters = Map("orcType" -> "\"TIMESTAMP_NTZ\"", "toType" -> "\"TIMESTAMP\""), sqlState = "0A000") @@ -349,7 +349,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkArithmeticException] { sql("select timestampadd(YEAR, 1000000, timestamp'2022-03-09 01:02:03')").collect() }, - errorClass = "DATETIME_OVERFLOW", + condition = "DATETIME_OVERFLOW", parameters = Map("operation" -> "add 1000000 YEAR to TIMESTAMP '2022-03-09 01:02:03'"), sqlState = "22008") } @@ -385,7 +385,7 @@ class QueryExecutionErrorsSuite checkError( exception = e2.getCause.asInstanceOf[SparkRuntimeException], - errorClass = 
"CANNOT_PARSE_DECIMAL", + condition = "CANNOT_PARSE_DECIMAL", parameters = Map[String, String](), sqlState = "22018") } @@ -397,7 +397,7 @@ class QueryExecutionErrorsSuite sql(s"SELECT from_json('$jsonStr', 'a INT, b DOUBLE', map('mode','FAILFAST') )") .collect() }, - errorClass = "MALFORMED_RECORD_IN_PARSING.CANNOT_PARSE_JSON_ARRAYS_AS_STRUCTS", + condition = "MALFORMED_RECORD_IN_PARSING.CANNOT_PARSE_JSON_ARRAYS_AS_STRUCTS", parameters = Map( "badRecord" -> jsonStr, "failFastMode" -> "FAILFAST" @@ -429,7 +429,7 @@ class QueryExecutionErrorsSuite .createOrReplaceTempView("words") spark.sql("select luckyCharOfWord(word, index) from words").collect() }, - errorClass = "FAILED_EXECUTE_UDF", + condition = "FAILED_EXECUTE_UDF", parameters = Map( "functionName" -> functionNameRegex, "signature" -> "string, int", @@ -458,7 +458,7 @@ class QueryExecutionErrorsSuite val words = Seq(("Jacek", 5), ("Agata", 5), ("Sweet", 6)).toDF("word", "index") words.select(luckyCharOfWord($"word", $"index")).collect() }, - errorClass = "FAILED_EXECUTE_UDF", + condition = "FAILED_EXECUTE_UDF", parameters = Map("functionName" -> functionNameRegex, "signature" -> "string, int", "result" -> "string", @@ -487,7 +487,7 @@ class QueryExecutionErrorsSuite } checkError( exception = e, - errorClass = "INCOMPARABLE_PIVOT_COLUMN", + condition = "INCOMPARABLE_PIVOT_COLUMN", parameters = Map("columnName" -> "`map`"), sqlState = "42818") } @@ -500,7 +500,7 @@ class QueryExecutionErrorsSuite } checkError( exception = e1, - errorClass = "UNSUPPORTED_SAVE_MODE.NON_EXISTENT_PATH", + condition = "UNSUPPORTED_SAVE_MODE.NON_EXISTENT_PATH", parameters = Map("saveMode" -> "NULL")) Utils.createDirectory(path) @@ -511,7 +511,7 @@ class QueryExecutionErrorsSuite } checkError( exception = e2, - errorClass = "UNSUPPORTED_SAVE_MODE.EXISTENT_PATH", + condition = "UNSUPPORTED_SAVE_MODE.EXISTENT_PATH", parameters = Map("saveMode" -> "NULL")) } } @@ -521,7 +521,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkException] { RuleIdCollection.getRuleId("incorrect") }, - errorClass = "RULE_ID_NOT_FOUND", + condition = "RULE_ID_NOT_FOUND", parameters = Map("ruleName" -> "incorrect") ) } @@ -540,7 +540,7 @@ class QueryExecutionErrorsSuite checkError( exception = e.getCause.asInstanceOf[SparkSecurityException], - errorClass = "CANNOT_RESTORE_PERMISSIONS_FOR_PATH", + condition = "CANNOT_RESTORE_PERMISSIONS_FOR_PATH", parameters = Map("permission" -> ".+", "path" -> ".+"), matchPVals = true) @@ -569,7 +569,7 @@ class QueryExecutionErrorsSuite } checkError( exception = e, - errorClass = "INCOMPATIBLE_DATASOURCE_REGISTER", + condition = "INCOMPATIBLE_DATASOURCE_REGISTER", parameters = Map("message" -> ("Illegal configuration-file syntax: " + "META-INF/services/org.apache.spark.sql.sources.DataSourceRegister"))) } @@ -650,7 +650,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkSQLException] { spark.read.jdbc(urlWithUserAndPass, tableName, new Properties()).collect() }, - errorClass = "UNRECOGNIZED_SQL_TYPE", + condition = "UNRECOGNIZED_SQL_TYPE", parameters = Map("typeName" -> unrecognizedColumnTypeName, "jdbcType" -> "DATALINK"), sqlState = "42704") @@ -675,7 +675,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkException] { aggregated.count() }, - errorClass = "INVALID_BUCKET_FILE", + condition = "INVALID_BUCKET_FILE", parameters = Map("path" -> ".+"), matchPVals = true) } @@ -688,7 +688,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkException] { sql("select (select a from (select 1 as a union all 
select 2 as a) t) as b").collect() }, - errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS", + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS", queryContext = Array( ExpectedContext( fragment = "(select a from (select 1 as a union all select 2 as a) t)", @@ -704,7 +704,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkArithmeticException]( sql("select add_months('5500000-12-31', 10000000)").collect() ), - errorClass = "ARITHMETIC_OVERFLOW", + condition = "ARITHMETIC_OVERFLOW", parameters = Map( "message" -> "integer overflow", "alternative" -> "", @@ -717,7 +717,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkRuntimeException] { StructType.fromString(raw) }, - errorClass = "FAILED_PARSE_STRUCT_TYPE", + condition = "FAILED_PARSE_STRUCT_TYPE", parameters = Map("raw" -> s"'$raw'")) } @@ -730,12 +730,11 @@ class QueryExecutionErrorsSuite exception = intercept[SparkArithmeticException] { sql(s"select CAST($sourceValue AS $it)").collect() }, - errorClass = "CAST_OVERFLOW", + condition = "CAST_OVERFLOW", parameters = Map( "value" -> sourceValue, "sourceType" -> s""""${sourceType.sql}"""", - "targetType" -> s""""$it"""", - "ansiConfig" -> s""""${SQLConf.ANSI_ENABLED.key}""""), + "targetType" -> s""""$it""""), sqlState = "22003") } } @@ -747,7 +746,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkRuntimeException] { sql(s"""SELECT from_json('$jsonStr', 'a FLOAT', map('mode','FAILFAST'))""").collect() }, - errorClass = "MALFORMED_RECORD_IN_PARSING.CANNOT_PARSE_STRING_AS_DATATYPE", + condition = "MALFORMED_RECORD_IN_PARSING.CANNOT_PARSE_STRING_AS_DATATYPE", parameters = Map( "badRecord" -> jsonStr, "failFastMode" -> "FAILFAST", @@ -764,7 +763,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkArithmeticException] { sql(s"select 127Y + 5Y").collect() }, - errorClass = "BINARY_ARITHMETIC_OVERFLOW", + condition = "BINARY_ARITHMETIC_OVERFLOW", parameters = Map( "value1" -> "127S", "symbol" -> "+", @@ -779,7 +778,7 @@ class QueryExecutionErrorsSuite val row = spark.sparkContext.parallelize(Seq(1, 2)).map(Row(_)) spark.createDataFrame(row, StructType.fromString("StructType()")) }, - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map( "typeName" -> "StructType()[1.1] failure: 'TimestampType' expected but 'S' found\n\nStructType()\n^" @@ -810,7 +809,7 @@ class QueryExecutionErrorsSuite val expectedPath = p.toURI checkError( exception = e, - errorClass = "RENAME_SRC_PATH_NOT_FOUND", + condition = "RENAME_SRC_PATH_NOT_FOUND", matchPVals = true, parameters = Map("sourcePath" -> s"$expectedPath.+") ) @@ -871,7 +870,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkSQLFeatureNotSupportedException] { sql("alter TABLE h2.test.people SET TBLPROPERTIES (xx='xx', yy='yy')") }, - errorClass = "UNSUPPORTED_FEATURE.MULTI_ACTION_ALTER", + condition = "UNSUPPORTED_FEATURE.MULTI_ACTION_ALTER", parameters = Map("tableName" -> "\"test\".\"people\"")) JdbcDialects.unregisterDialect(testH2DialectUnsupportedJdbcTransaction) @@ -927,7 +926,7 @@ class QueryExecutionErrorsSuite exceptions.flatten.map { e => checkError( e, - errorClass = "CONCURRENT_QUERY", + condition = "CONCURRENT_QUERY", sqlState = Some("0A000"), parameters = e.getMessageParameters.asScala.toMap ) @@ -948,7 +947,7 @@ class QueryExecutionErrorsSuite checkError( exception = e, - errorClass = "UNSUPPORTED_EXPR_FOR_WINDOW", + condition = "UNSUPPORTED_EXPR_FOR_WINDOW", parameters = Map( "sqlExpr" -> "\"to_date(2009-07-30 04:17:52)\"" ), @@ -969,7 +968,7 @@ class 
QueryExecutionErrorsSuite } checkError( exception = e, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "Cannot evaluate expression: namedparameter(foo)"), sqlState = "XX000") } @@ -981,7 +980,7 @@ class QueryExecutionErrorsSuite } checkError( exception = e, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map( "message" -> ("Cannot generate code for expression: " + "grouping(namedparameter(foo))")), @@ -994,7 +993,7 @@ class QueryExecutionErrorsSuite } checkError( exception = e, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "Cannot terminate expression: 'foo()"), sqlState = "XX000") } @@ -1008,7 +1007,7 @@ class QueryExecutionErrorsSuite } checkError( exception = e, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map( "message" -> ("""A method named "nonexistent" is not declared in """ + "any enclosing class nor any supertype")), @@ -1021,7 +1020,7 @@ class QueryExecutionErrorsSuite } checkError( exception = e, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map( "message" -> "The aggregate window function `row_number` does not support merging."), sqlState = "XX000") @@ -1032,7 +1031,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkArrayIndexOutOfBoundsException] { sql("select bitmap_construct_agg(col) from values (32768) as tab(col)").collect() }, - errorClass = "INVALID_BITMAP_POSITION", + condition = "INVALID_BITMAP_POSITION", parameters = Map( "bitPosition" -> "32768", "bitmapNumBytes" -> "4096", @@ -1045,7 +1044,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkArrayIndexOutOfBoundsException] { sql("select bitmap_construct_agg(col) from values (-1) as tab(col)").collect() }, - errorClass = "INVALID_BITMAP_POSITION", + condition = "INVALID_BITMAP_POSITION", parameters = Map( "bitPosition" -> "-1", "bitmapNumBytes" -> "4096", @@ -1060,7 +1059,7 @@ class QueryExecutionErrorsSuite maxBroadcastTableBytes = 1024 * 1024 * 1024, dataSize = 2 * 1024 * 1024 * 1024 - 1) }, - errorClass = "_LEGACY_ERROR_TEMP_2249", + condition = "_LEGACY_ERROR_TEMP_2249", parameters = Map("maxBroadcastTableBytes" -> "1024.0 MiB", "dataSize" -> "2048.0 MiB")) } @@ -1071,7 +1070,7 @@ class QueryExecutionErrorsSuite exception = intercept[AnalysisException] { sql("SELECT * FROM t TIMESTAMP AS OF '2021-01-29 00:00:00'").collect() }, - errorClass = "UNSUPPORTED_FEATURE.TIME_TRAVEL", + condition = "UNSUPPORTED_FEATURE.TIME_TRAVEL", parameters = Map("relationId" -> "`spark_catalog`.`default`.`t`") ) } @@ -1082,7 +1081,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkRuntimeException] { sql("select slice(array(1,2,3), 0, 1)").collect() }, - errorClass = "INVALID_PARAMETER_VALUE.START", + condition = "INVALID_PARAMETER_VALUE.START", parameters = Map( "parameter" -> toSQLId("start"), "functionName" -> toSQLId("slice") @@ -1095,7 +1094,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkRuntimeException] { sql("select slice(array(1,2,3), 1, -1)").collect() }, - errorClass = "INVALID_PARAMETER_VALUE.LENGTH", + condition = "INVALID_PARAMETER_VALUE.LENGTH", parameters = Map( "parameter" -> toSQLId("length"), "length" -> (-1).toString, @@ -1112,7 +1111,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkRuntimeException] { Concat(Seq(Literal.create(array, ArrayType(BooleanType)))).eval(EmptyRow) }, - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.FUNCTION", + condition = 
"COLLECTION_SIZE_LIMIT_EXCEEDED.FUNCTION", parameters = Map( "numberOfElements" -> Int.MaxValue.toString, "maxRoundedArrayLength" -> MAX_ROUNDED_ARRAY_LENGTH.toString, @@ -1129,7 +1128,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkRuntimeException] { Flatten(CreateArray(Seq(Literal.create(array, ArrayType(BooleanType))))).eval(EmptyRow) }, - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.FUNCTION", + condition = "COLLECTION_SIZE_LIMIT_EXCEEDED.FUNCTION", parameters = Map( "numberOfElements" -> Int.MaxValue.toString, "maxRoundedArrayLength" -> MAX_ROUNDED_ARRAY_LENGTH.toString, @@ -1144,7 +1143,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkRuntimeException] { sql(s"select array_repeat(1, $count)").collect() }, - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", + condition = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", parameters = Map( "parameter" -> toSQLId("count"), "numberOfElements" -> count.toString, @@ -1164,7 +1163,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkRuntimeException] { Seq(KryoData(1), KryoData(2)).toDS() }, - errorClass = "INVALID_EXPRESSION_ENCODER", + condition = "INVALID_EXPRESSION_ENCODER", parameters = Map( "encoderType" -> kryoEncoder.getClass.getName, "docroot" -> SPARK_DOC_ROOT @@ -1177,13 +1176,13 @@ class QueryExecutionErrorsSuite val enc: ExpressionEncoder[Row] = ExpressionEncoder(rowEnc) val deserializer = AttributeReference.apply("v", IntegerType)() implicit val im: ExpressionEncoder[Row] = new ExpressionEncoder[Row]( - enc.objSerializer, deserializer, enc.clsTag) + rowEnc, enc.objSerializer, deserializer) checkError( exception = intercept[SparkRuntimeException] { spark.createDataset(Seq(Row(1))).collect() }, - errorClass = "NOT_UNRESOLVED_ENCODER", + condition = "NOT_UNRESOLVED_ENCODER", parameters = Map( "attr" -> deserializer.toString ) @@ -1206,7 +1205,7 @@ class QueryExecutionErrorsSuite exception = intercept[SparkRuntimeException] { expr.eval(EmptyRow) }, - errorClass = "CLASS_NOT_OVERRIDE_EXPECTED_METHOD", + condition = "CLASS_NOT_OVERRIDE_EXPECTED_METHOD", parameters = Map( "className" -> expr.getClass.getName, "method1" -> "eval", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala index b7fb65091ef73..da7b6e7f63c85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala @@ -39,7 +39,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL } checkError( exception = parseException(query), - errorClass = "FAILED_TO_PARSE_TOO_COMPLEX", + condition = "FAILED_TO_PARSE_TOO_COMPLEX", parameters = Map(), context = ExpectedContext( query, @@ -53,7 +53,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL ", 2 as first, 3 as second, 4 as second, 5 as third" checkError( exception = parseException(query), - errorClass = "EXEC_IMMEDIATE_DUPLICATE_ARGUMENT_ALIASES", + condition = "EXEC_IMMEDIATE_DUPLICATE_ARGUMENT_ALIASES", parameters = Map("aliases" -> "`second`, `first`"), context = ExpectedContext( "USING 1 as first, 2 as first, 3 as second, 4 as second, 5 as third", @@ -66,7 +66,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL val query = "EXECUTE IMMEDIATE 'SELCT 1707 WHERE ? 
= 1' INTO a USING 1" checkError( exception = parseException(query), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'SELCT'", "hint" -> ""), context = ExpectedContext( start = 0, @@ -79,7 +79,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL withSQLConf("spark.sql.allowNamedFunctionArguments" -> "false") { checkError( exception = parseException("SELECT explode(arr => array(10, 20))"), - errorClass = "NAMED_PARAMETER_SUPPORT_DISABLED", + condition = "NAMED_PARAMETER_SUPPORT_DISABLED", parameters = Map("functionName"-> toSQLId("explode"), "argument" -> toSQLId("arr")) ) } @@ -88,7 +88,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("UNSUPPORTED_FEATURE: LATERAL join with NATURAL join not supported") { checkError( exception = parseException("SELECT * FROM t1 NATURAL JOIN LATERAL (SELECT c1 + c2 AS c2)"), - errorClass = "INCOMPATIBLE_JOIN_TYPES", + condition = "INCOMPATIBLE_JOIN_TYPES", parameters = Map("joinType1" -> "LATERAL", "joinType2" -> "NATURAL"), sqlState = "42613", context = ExpectedContext( @@ -100,7 +100,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("UNSUPPORTED_FEATURE: LATERAL join with USING join not supported") { checkError( exception = parseException("SELECT * FROM t1 JOIN LATERAL (SELECT c1 + c2 AS c2) USING (c2)"), - errorClass = "UNSUPPORTED_FEATURE.LATERAL_JOIN_USING", + condition = "UNSUPPORTED_FEATURE.LATERAL_JOIN_USING", sqlState = "0A000", context = ExpectedContext( fragment = "JOIN LATERAL (SELECT c1 + c2 AS c2) USING (c2)", @@ -116,7 +116,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL "LEFT ANTI" -> (17, 72)).foreach { case (joinType, (start, stop)) => checkError( exception = parseException(s"SELECT * FROM t1 $joinType JOIN LATERAL (SELECT c1 + c2 AS c3) ON c2 = c3"), - errorClass = "INVALID_LATERAL_JOIN_TYPE", + condition = "INVALID_LATERAL_JOIN_TYPE", parameters = Map("joinType" -> joinType), context = ExpectedContext( fragment = s"$joinType JOIN LATERAL (SELECT c1 + c2 AS c3) ON c2 = c3", @@ -136,7 +136,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL ).foreach { case (sqlText, (fragment, start, stop)) => checkError( exception = parseException(s"SELECT * FROM t1$sqlText"), - errorClass = "INVALID_SQL_SYNTAX.LATERAL_WITHOUT_SUBQUERY_OR_TABLE_VALUED_FUNC", + condition = "INVALID_SQL_SYNTAX.LATERAL_WITHOUT_SUBQUERY_OR_TABLE_VALUED_FUNC", sqlState = "42000", context = ExpectedContext(fragment, start, stop)) } @@ -145,7 +145,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("UNSUPPORTED_FEATURE: NATURAL CROSS JOIN is not supported") { checkError( exception = parseException("SELECT * FROM a NATURAL CROSS JOIN b"), - errorClass = "INCOMPATIBLE_JOIN_TYPES", + condition = "INCOMPATIBLE_JOIN_TYPES", parameters = Map("joinType1" -> "NATURAL", "joinType2" -> "CROSS"), sqlState = "42613", context = ExpectedContext( @@ -157,7 +157,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("INVALID_SQL_SYNTAX.REPETITIVE_WINDOW_DEFINITION: redefine window") { checkError( exception = parseException("SELECT min(a) OVER win FROM t1 WINDOW win AS win, win AS win2"), - errorClass = "INVALID_SQL_SYNTAX.REPETITIVE_WINDOW_DEFINITION", + condition = "INVALID_SQL_SYNTAX.REPETITIVE_WINDOW_DEFINITION", sqlState = "42000", parameters = Map("windowName" -> "`win`"), 
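// [Illustrative sketch, not part of the patch] The hunks in these suites only rename the
// `errorClass` keyword argument of the shared `checkError` test helper to `condition`;
// the asserted condition names, parameters, SQLSTATEs and query contexts are unchanged.
// Reusing the RULE_ID_NOT_FOUND case from QueryExecutionErrorsSuite above (this assumes a
// suite that already mixes in SharedSparkSession and imports RuleIdCollection):
checkError(
  exception = intercept[SparkException] {
    RuleIdCollection.getRuleId("incorrect")
  },
  // previously written as: errorClass = "RULE_ID_NOT_FOUND"
  condition = "RULE_ID_NOT_FOUND",
  parameters = Map("ruleName" -> "incorrect"))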
context = ExpectedContext( @@ -169,7 +169,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("INVALID_SQL_SYNTAX.INVALID_WINDOW_REFERENCE: invalid window reference") { checkError( exception = parseException("SELECT min(a) OVER win FROM t1 WINDOW win AS win"), - errorClass = "INVALID_SQL_SYNTAX.INVALID_WINDOW_REFERENCE", + condition = "INVALID_SQL_SYNTAX.INVALID_WINDOW_REFERENCE", sqlState = "42000", parameters = Map("windowName" -> "`win`"), context = ExpectedContext( @@ -181,7 +181,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("INVALID_SQL_SYNTAX.UNRESOLVED_WINDOW_REFERENCE: window reference cannot be resolved") { checkError( exception = parseException("SELECT min(a) OVER win FROM t1 WINDOW win AS win2"), - errorClass = "INVALID_SQL_SYNTAX.UNRESOLVED_WINDOW_REFERENCE", + condition = "INVALID_SQL_SYNTAX.UNRESOLVED_WINDOW_REFERENCE", sqlState = "42000", parameters = Map("windowName" -> "`win2`"), context = ExpectedContext( @@ -194,7 +194,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL val sqlText = "SELECT TRANSFORM(DISTINCT a) USING 'a' FROM t" checkError( exception = parseException(sqlText), - errorClass = "UNSUPPORTED_FEATURE.TRANSFORM_DISTINCT_ALL", + condition = "UNSUPPORTED_FEATURE.TRANSFORM_DISTINCT_ALL", sqlState = "0A000", context = ExpectedContext( fragment = sqlText, @@ -207,7 +207,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL "'org.apache.hadoop.hive.serde2.OpenCSVSerde' USING 'a' FROM t" checkError( exception = parseException(sqlText), - errorClass = "UNSUPPORTED_FEATURE.TRANSFORM_NON_HIVE", + condition = "UNSUPPORTED_FEATURE.TRANSFORM_NON_HIVE", sqlState = "0A000", context = ExpectedContext( fragment = sqlText, @@ -218,7 +218,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("INVALID_SQL_SYNTAX.TRANSFORM_WRONG_NUM_ARGS: Wrong number arguments for transform") { checkError( exception = parseException("CREATE TABLE table(col int) PARTITIONED BY (years(col,col))"), - errorClass = "INVALID_SQL_SYNTAX.TRANSFORM_WRONG_NUM_ARGS", + condition = "INVALID_SQL_SYNTAX.TRANSFORM_WRONG_NUM_ARGS", sqlState = "42000", parameters = Map( "transform" -> "`years`", @@ -233,7 +233,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("INVALID_SQL_SYNTAX.INVALID_TABLE_VALUED_FUNC_NAME: Invalid table value function name") { checkError( exception = parseException("SELECT * FROM db.func()"), - errorClass = "INVALID_SQL_SYNTAX.INVALID_TABLE_VALUED_FUNC_NAME", + condition = "INVALID_SQL_SYNTAX.INVALID_TABLE_VALUED_FUNC_NAME", sqlState = "42000", parameters = Map("funcName" -> "`db`.`func`"), context = ExpectedContext( @@ -243,7 +243,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL checkError( exception = parseException("SELECT * FROM ns.db.func()"), - errorClass = "INVALID_SQL_SYNTAX.INVALID_TABLE_VALUED_FUNC_NAME", + condition = "INVALID_SQL_SYNTAX.INVALID_TABLE_VALUED_FUNC_NAME", sqlState = "42000", parameters = Map("funcName" -> "`ns`.`db`.`func`"), context = ExpectedContext( @@ -256,7 +256,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL val sqlText = "SHOW sys FUNCTIONS" checkError( exception = parseException(sqlText), - errorClass = "INVALID_SQL_SYNTAX.SHOW_FUNCTIONS_INVALID_SCOPE", + condition = "INVALID_SQL_SYNTAX.SHOW_FUNCTIONS_INVALID_SCOPE", sqlState = "42000", parameters = 
Map("scope" -> "`sys`"), context = ExpectedContext( @@ -269,7 +269,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL val sqlText1 = "SHOW FUNCTIONS IN db f1" checkError( exception = parseException(sqlText1), - errorClass = "INVALID_SQL_SYNTAX.SHOW_FUNCTIONS_INVALID_PATTERN", + condition = "INVALID_SQL_SYNTAX.SHOW_FUNCTIONS_INVALID_PATTERN", sqlState = "42000", parameters = Map("pattern" -> "`f1`"), context = ExpectedContext( @@ -279,7 +279,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL val sqlText2 = "SHOW FUNCTIONS IN db LIKE f1" checkError( exception = parseException(sqlText2), - errorClass = "INVALID_SQL_SYNTAX.SHOW_FUNCTIONS_INVALID_PATTERN", + condition = "INVALID_SQL_SYNTAX.SHOW_FUNCTIONS_INVALID_PATTERN", sqlState = "42000", parameters = Map("pattern" -> "`f1`"), context = ExpectedContext( @@ -297,7 +297,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL checkError( exception = parseException(sqlText), - errorClass = "INVALID_SQL_SYNTAX.CREATE_ROUTINE_WITH_IF_NOT_EXISTS_AND_REPLACE", + condition = "INVALID_SQL_SYNTAX.CREATE_ROUTINE_WITH_IF_NOT_EXISTS_AND_REPLACE", sqlState = "42000", context = ExpectedContext( fragment = sqlText, @@ -314,7 +314,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL checkError( exception = parseException(sqlText), - errorClass = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_IF_NOT_EXISTS", + condition = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_IF_NOT_EXISTS", sqlState = "42000", context = ExpectedContext( fragment = sqlText, @@ -330,7 +330,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL checkError( exception = parseException(sqlText), - errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", + condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", sqlState = "42000", parameters = Map( "statement" -> "CREATE TEMPORARY FUNCTION", @@ -350,7 +350,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL checkError( exception = parseException(sqlText), - errorClass = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_DATABASE", + condition = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_DATABASE", sqlState = "42000", parameters = Map("database" -> "`db`"), context = ExpectedContext( @@ -363,7 +363,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL val sqlText = "DROP TEMPORARY FUNCTION db.func" checkError( exception = parseException(sqlText), - errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", + condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", sqlState = "42000", parameters = Map( "statement" -> "DROP TEMPORARY FUNCTION", @@ -377,7 +377,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("DUPLICATE_KEY: Found duplicate partition keys") { checkError( exception = parseException("INSERT OVERWRITE TABLE table PARTITION(p1='1', p1='1') SELECT 'col1', 'col2'"), - errorClass = "DUPLICATE_KEY", + condition = "DUPLICATE_KEY", sqlState = "23505", parameters = Map("keyColumn" -> "`p1`"), context = ExpectedContext( @@ -389,7 +389,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("DUPLICATE_KEY: in table properties") { checkError( exception = parseException("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('key1' = '1', 'key1' = '2')"), - errorClass = "DUPLICATE_KEY", + condition = "DUPLICATE_KEY", sqlState = "23505", parameters = Map("keyColumn" -> "`key1`"), context = 
ExpectedContext( @@ -401,24 +401,24 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("PARSE_EMPTY_STATEMENT: empty input") { checkError( exception = parseException(""), - errorClass = "PARSE_EMPTY_STATEMENT", + condition = "PARSE_EMPTY_STATEMENT", sqlState = Some("42617")) checkError( exception = parseException(" "), - errorClass = "PARSE_EMPTY_STATEMENT", + condition = "PARSE_EMPTY_STATEMENT", sqlState = Some("42617")) checkError( exception = parseException(" \n"), - errorClass = "PARSE_EMPTY_STATEMENT", + condition = "PARSE_EMPTY_STATEMENT", sqlState = Some("42617")) } test("PARSE_SYNTAX_ERROR: no viable input") { checkError( exception = parseException("select ((r + 1) "), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map("error" -> "end of input", "hint" -> "")) } @@ -426,7 +426,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL def checkParseSyntaxError(sqlCommand: String, errorString: String, hint: String = ""): Unit = { checkError( exception = parseException(sqlCommand), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map("error" -> errorString, "hint" -> hint) ) @@ -444,13 +444,13 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("PARSE_SYNTAX_ERROR: extraneous input") { checkError( exception = parseException("select 1 1"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map("error" -> "'1'", "hint" -> ": extra input '1'")) checkError( exception = parseException("select *\nfrom r as q t"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map("error" -> "'t'", "hint" -> ": extra input 't'")) } @@ -458,13 +458,13 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("PARSE_SYNTAX_ERROR: mismatched input") { checkError( exception = parseException("select * from r order by q from t"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map("error" -> "'from'", "hint" -> "")) checkError( exception = parseException("select *\nfrom r\norder by q\nfrom t"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map("error" -> "'from'", "hint" -> "")) } @@ -473,13 +473,13 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL // '' -> end of input checkError( exception = parseException("select count(*"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map("error" -> "end of input", "hint" -> "")) checkError( exception = parseException("select 1 as a from"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map("error" -> "end of input", "hint" -> "")) } @@ -488,19 +488,19 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL "misleading error message due to problematic antlr grammar") { checkError( exception = parseException("select * from a left join_ b on a.id = b.id"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map("error" -> "'join_'", "hint" -> ": missing 'JOIN'")) checkError( exception = parseException("select * from test where test.t is like 'test'"), - errorClass = 
"PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map("error" -> "'is'", "hint" -> "")) checkError( exception = parseException("SELECT * FROM test WHERE x NOT NULL"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map("error" -> "'NOT'", "hint" -> "")) } @@ -508,7 +508,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("INVALID_SQL_SYNTAX.EMPTY_PARTITION_VALUE: show table partition key must set value") { checkError( exception = parseException("SHOW TABLE EXTENDED IN default LIKE 'employee' PARTITION (grade)"), - errorClass = "INVALID_SQL_SYNTAX.EMPTY_PARTITION_VALUE", + condition = "INVALID_SQL_SYNTAX.EMPTY_PARTITION_VALUE", sqlState = "42000", parameters = Map("partKey" -> "`grade`"), context = ExpectedContext( @@ -522,7 +522,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL checkError( exception = parseException("CREATE TABLE my_tab(a INT, b STRING) " + "USING parquet PARTITIONED BY (bucket(32, a, 66))"), - errorClass = "INVALID_SQL_SYNTAX.INVALID_COLUMN_REFERENCE", + condition = "INVALID_SQL_SYNTAX.INVALID_COLUMN_REFERENCE", sqlState = "42000", parameters = Map( "transform" -> "`bucket`", @@ -537,7 +537,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL val sqlText = "DESCRIBE TABLE EXTENDED customer PARTITION (grade = 'A') customer.age" checkError( exception = parseException(sqlText), - errorClass = "UNSUPPORTED_FEATURE.DESC_TABLE_COLUMN_PARTITION", + condition = "UNSUPPORTED_FEATURE.DESC_TABLE_COLUMN_PARTITION", sqlState = "0A000", context = ExpectedContext( fragment = sqlText, @@ -549,7 +549,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL val sqlText = "DESCRIBE TABLE EXTENDED customer PARTITION (grade)" checkError( exception = parseException(sqlText), - errorClass = "INVALID_SQL_SYNTAX.EMPTY_PARTITION_VALUE", + condition = "INVALID_SQL_SYNTAX.EMPTY_PARTITION_VALUE", sqlState = "42000", parameters = Map("partKey" -> "`grade`"), context = ExpectedContext( @@ -562,7 +562,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL val sqlText = "CREATE NAMESPACE IF NOT EXISTS a.b.c WITH PROPERTIES ('location'='/home/user/db')" checkError( exception = parseException(sqlText), - errorClass = "UNSUPPORTED_FEATURE.SET_NAMESPACE_PROPERTY", + condition = "UNSUPPORTED_FEATURE.SET_NAMESPACE_PROPERTY", sqlState = "0A000", parameters = Map( "property" -> "location", @@ -578,7 +578,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL "USING PARQUET TBLPROPERTIES ('provider'='parquet')" checkError( exception = parseException(sqlText), - errorClass = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", + condition = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", sqlState = "0A000", parameters = Map( "property" -> "provider", @@ -593,7 +593,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL val sqlText = "set =`value`" checkError( exception = parseException(sqlText), - errorClass = "INVALID_PROPERTY_KEY", + condition = "INVALID_PROPERTY_KEY", parameters = Map("key" -> "\"\"", "value" -> "\"value\""), context = ExpectedContext( fragment = sqlText, @@ -604,7 +604,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("INVALID_PROPERTY_VALUE: invalid property value for set quoted configuration") { checkError( exception = parseException("set 
`key`=1;2;;"), - errorClass = "INVALID_PROPERTY_VALUE", + condition = "INVALID_PROPERTY_VALUE", parameters = Map("value" -> "\"1;2;;\"", "key" -> "\"key\""), context = ExpectedContext( fragment = "set `key`=1;2", @@ -617,7 +617,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL "WITH DBPROPERTIES('a'='a', 'b'='b', 'c'='c')" checkError( exception = parseException(sqlText), - errorClass = "UNSUPPORTED_FEATURE.SET_PROPERTIES_AND_DBPROPERTIES", + condition = "UNSUPPORTED_FEATURE.SET_PROPERTIES_AND_DBPROPERTIES", sqlState = "0A000", context = ExpectedContext( fragment = sqlText, @@ -629,28 +629,28 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL // Cast simple array without specifying element type checkError( exception = parseException("SELECT CAST(array(1,2,3) AS ARRAY)"), - errorClass = "INCOMPLETE_TYPE_DEFINITION.ARRAY", + condition = "INCOMPLETE_TYPE_DEFINITION.ARRAY", sqlState = "42K01", parameters = Map("elementType" -> ""), context = ExpectedContext(fragment = "ARRAY", start = 28, stop = 32)) // Cast array of array without specifying element type for inner array checkError( exception = parseException("SELECT CAST(array(array(3)) AS ARRAY)"), - errorClass = "INCOMPLETE_TYPE_DEFINITION.ARRAY", + condition = "INCOMPLETE_TYPE_DEFINITION.ARRAY", sqlState = "42K01", parameters = Map("elementType" -> ""), context = ExpectedContext(fragment = "ARRAY", start = 37, stop = 41)) // Create column of array type without specifying element type checkError( exception = parseException("CREATE TABLE tbl_120691 (col1 ARRAY)"), - errorClass = "INCOMPLETE_TYPE_DEFINITION.ARRAY", + condition = "INCOMPLETE_TYPE_DEFINITION.ARRAY", sqlState = "42K01", parameters = Map("elementType" -> ""), context = ExpectedContext(fragment = "ARRAY", start = 30, stop = 34)) // Create column of array type without specifying element type in lowercase checkError( exception = parseException("CREATE TABLE tbl_120691 (col1 array)"), - errorClass = "INCOMPLETE_TYPE_DEFINITION.ARRAY", + condition = "INCOMPLETE_TYPE_DEFINITION.ARRAY", sqlState = "42K01", parameters = Map("elementType" -> ""), context = ExpectedContext(fragment = "array", start = 30, stop = 34)) @@ -660,31 +660,31 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL // Cast simple struct without specifying field type checkError( exception = parseException("SELECT CAST(struct(1,2,3) AS STRUCT)"), - errorClass = "INCOMPLETE_TYPE_DEFINITION.STRUCT", + condition = "INCOMPLETE_TYPE_DEFINITION.STRUCT", sqlState = "42K01", context = ExpectedContext(fragment = "STRUCT", start = 29, stop = 34)) // Cast array of struct without specifying field type in struct checkError( exception = parseException("SELECT CAST(array(struct(1,2)) AS ARRAY)"), - errorClass = "INCOMPLETE_TYPE_DEFINITION.STRUCT", + condition = "INCOMPLETE_TYPE_DEFINITION.STRUCT", sqlState = "42K01", context = ExpectedContext(fragment = "STRUCT", start = 40, stop = 45)) // Create column of struct type without specifying field type checkError( exception = parseException("CREATE TABLE tbl_120691 (col1 STRUCT)"), - errorClass = "INCOMPLETE_TYPE_DEFINITION.STRUCT", + condition = "INCOMPLETE_TYPE_DEFINITION.STRUCT", sqlState = "42K01", context = ExpectedContext(fragment = "STRUCT", start = 30, stop = 35)) // Invalid syntax `STRUCT` without field name checkError( exception = parseException("SELECT CAST(struct(1,2,3) AS STRUCT)"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", 
parameters = Map("error" -> "'<'", "hint" -> ": missing ')'")) // Create column of struct type without specifying field type in lowercase checkError( exception = parseException("CREATE TABLE tbl_120691 (col1 struct)"), - errorClass = "INCOMPLETE_TYPE_DEFINITION.STRUCT", + condition = "INCOMPLETE_TYPE_DEFINITION.STRUCT", sqlState = "42K01", context = ExpectedContext(fragment = "struct", start = 30, stop = 35)) } @@ -693,25 +693,25 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL // Cast simple map without specifying element type checkError( exception = parseException("SELECT CAST(map(1,'2') AS MAP)"), - errorClass = "INCOMPLETE_TYPE_DEFINITION.MAP", + condition = "INCOMPLETE_TYPE_DEFINITION.MAP", sqlState = "42K01", context = ExpectedContext(fragment = "MAP", start = 26, stop = 28)) // Create column of map type without specifying key/value types checkError( exception = parseException("CREATE TABLE tbl_120691 (col1 MAP)"), - errorClass = "INCOMPLETE_TYPE_DEFINITION.MAP", + condition = "INCOMPLETE_TYPE_DEFINITION.MAP", sqlState = "42K01", context = ExpectedContext(fragment = "MAP", start = 30, stop = 32)) // Invalid syntax `MAP` with only key type checkError( exception = parseException("SELECT CAST(map('1',2) AS MAP)"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", sqlState = "42601", parameters = Map("error" -> "'<'", "hint" -> ": missing ')'")) // Create column of map type without specifying key/value types in lowercase checkError( exception = parseException("SELECT CAST(map('1',2) AS map)"), - errorClass = "INCOMPLETE_TYPE_DEFINITION.MAP", + condition = "INCOMPLETE_TYPE_DEFINITION.MAP", sqlState = "42K01", context = ExpectedContext(fragment = "map", start = 26, stop = 28)) } @@ -719,7 +719,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL test("INVALID_ESC: Escape string must contain only one character") { checkError( exception = parseException("select * from test where test.t like 'pattern%' escape '##'"), - errorClass = "INVALID_ESC", + condition = "INVALID_ESC", parameters = Map("invalidEscape" -> "'##'"), context = ExpectedContext( fragment = "like 'pattern%' escape '##'", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala index dc72b4a092aef..9ed4f1a006b2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala @@ -317,7 +317,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper { import spark.implicits._ spark.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "1KB") spark.conf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key, "10KB") - spark.conf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR, 2.0) + spark.conf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key, "2.0") val df00 = spark.range(0, 1000, 2) .selectExpr("id as key", "id as value") .union(Seq.fill(100000)((600, 600)).toDF("key", "value")) @@ -345,7 +345,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper { import spark.implicits._ spark.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") spark.conf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key, "100B") - spark.conf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR, 2.0) + 
spark.conf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key, "2.0") val df00 = spark.range(0, 10, 2) .selectExpr("id as key", "id as value") .union(Seq.fill(1000)((600, 600)).toDF("key", "value")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExecuteImmediateEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExecuteImmediateEndToEndSuite.scala index cae22eda32f80..62a32da22d957 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExecuteImmediateEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExecuteImmediateEndToEndSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{QueryTest} -import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.test.SharedSparkSession class ExecuteImmediateEndToEndSuite extends QueryTest with SharedSparkSession { @@ -37,30 +36,4 @@ class ExecuteImmediateEndToEndSuite extends QueryTest with SharedSparkSession { spark.sql("DROP TEMPORARY VARIABLE IF EXISTS parm;") } } - - test("EXEC IMMEDIATE STACK OVERFLOW") { - try { - spark.sql("DECLARE parm = 1;") - val query = (1 to 20000).map(x => "SELECT 1 as a").mkString(" UNION ALL ") - Seq( - s"EXECUTE IMMEDIATE '$query'", - s"EXECUTE IMMEDIATE '$query' INTO parm").foreach { q => - val e = intercept[ParseException] { - spark.sql(q) - } - - checkError( - exception = e, - errorClass = "FAILED_TO_PARSE_TOO_COMPLEX", - parameters = Map(), - context = ExpectedContext( - query, - start = 0, - stop = query.length - 1) - ) - } - } finally { - spark.sql("DROP TEMPORARY VARIABLE IF EXISTS parm;") - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index 94d33731b6de5..059a4c9b83763 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -228,7 +228,7 @@ class SQLExecutionSuite extends SparkFunSuite with SQLConfHelper { spark.range(1).collect() spark.sparkContext.listenerBus.waitUntilEmpty() - assert(jobTags.contains(jobTag)) + assert(jobTags.get.contains(jobTag)) assert(sqlJobTags.contains(jobTag)) } finally { spark.sparkContext.removeJobTag(jobTag) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index f54a4f4606061..b26cdfaeb756a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -85,7 +85,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql("CREATE VIEW jtv1 AS SELECT * FROM temp_jtv1 WHERE id < 6") }, - errorClass = "INVALID_TEMP_OBJ_REFERENCE", + condition = "INVALID_TEMP_OBJ_REFERENCE", parameters = Map( "obj" -> "VIEW", "objName" -> s"`$SESSION_CATALOG_NAME`.`default`.`jtv1`", @@ -97,7 +97,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"CREATE VIEW jtv1 AS SELECT * FROM $globalTempDB.global_temp_jtv1 WHERE id < 6") }, - errorClass = "INVALID_TEMP_OBJ_REFERENCE", + condition = "INVALID_TEMP_OBJ_REFERENCE", parameters = Map( "obj" -> "VIEW", "objName" -> s"`$SESSION_CATALOG_NAME`.`default`.`jtv1`", @@ -115,7 +115,7 @@ abstract class SQLViewSuite 
extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql("CREATE OR REPLACE VIEW tab1 AS SELECT * FROM jt") }, - errorClass = "EXPECT_VIEW_NOT_TABLE.NO_ALTERNATIVE", + condition = "EXPECT_VIEW_NOT_TABLE.NO_ALTERNATIVE", parameters = Map( "tableName" -> s"`$SESSION_CATALOG_NAME`.`default`.`tab1`", "operation" -> "CREATE OR REPLACE VIEW") @@ -124,7 +124,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql("CREATE VIEW tab1 AS SELECT * FROM jt") }, - errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map( "relationName" -> s"`$SESSION_CATALOG_NAME`.`default`.`tab1`") ) @@ -132,7 +132,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql("ALTER VIEW tab1 AS SELECT * FROM jt") }, - errorClass = "EXPECT_VIEW_NOT_TABLE.NO_ALTERNATIVE", + condition = "EXPECT_VIEW_NOT_TABLE.NO_ALTERNATIVE", parameters = Map( "tableName" -> s"`$SESSION_CATALOG_NAME`.`default`.`tab1`", "operation" -> "ALTER VIEW ... AS" @@ -161,7 +161,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER VIEW $viewName SET TBLPROPERTIES ('p' = 'an')") }, - errorClass = "EXPECT_PERMANENT_VIEW_NOT_TEMP", + condition = "EXPECT_PERMANENT_VIEW_NOT_TEMP", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "ALTER VIEW ... SET TBLPROPERTIES" @@ -176,7 +176,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER VIEW $viewName UNSET TBLPROPERTIES ('p')") }, - errorClass = "EXPECT_PERMANENT_VIEW_NOT_TEMP", + condition = "EXPECT_PERMANENT_VIEW_NOT_TEMP", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "ALTER VIEW ... UNSET TBLPROPERTIES" @@ -198,7 +198,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $viewName SET SERDE 'whatever'") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", + condition = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]" @@ -209,7 +209,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $viewName PARTITION (a=1, b=2) SET SERDE 'whatever'") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", + condition = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]" @@ -220,7 +220,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $viewName SET SERDEPROPERTIES ('p' = 'an')") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", + condition = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]" @@ -231,7 +231,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $viewName PARTITION (a='4') RENAME TO PARTITION (a='5')") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "ALTER TABLE ... 
RENAME TO PARTITION" @@ -242,7 +242,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $viewName RECOVER PARTITIONS") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "ALTER TABLE ... RECOVER PARTITIONS" @@ -253,7 +253,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $viewName SET LOCATION '/path/to/your/lovely/heart'") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "ALTER TABLE ... SET LOCATION ..." @@ -264,7 +264,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $viewName PARTITION (a='4') SET LOCATION '/path/to/home'") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "ALTER TABLE ... SET LOCATION ..." @@ -275,7 +275,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $viewName ADD IF NOT EXISTS PARTITION (a='4', b='8')") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "ALTER TABLE ... ADD PARTITION ..." @@ -286,7 +286,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $viewName DROP PARTITION (a='4', b='8')") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "ALTER TABLE ... DROP PARTITION ..." @@ -297,7 +297,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $viewName SET TBLPROPERTIES ('p' = 'an')") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", + condition = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "ALTER TABLE ... SET TBLPROPERTIES" @@ -308,7 +308,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $viewName UNSET TBLPROPERTIES ('p')") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", + condition = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "ALTER TABLE ... 
UNSET TBLPROPERTIES" @@ -327,7 +327,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"INSERT INTO TABLE $viewName SELECT 1") }, - errorClass = "UNSUPPORTED_INSERT.RDD_BASED", + condition = "UNSUPPORTED_INSERT.RDD_BASED", parameters = Map.empty ) @@ -338,7 +338,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "LOAD DATA" @@ -353,7 +353,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"SHOW CREATE TABLE $viewName") }, - errorClass = "EXPECT_PERMANENT_VIEW_NOT_TEMP", + condition = "EXPECT_PERMANENT_VIEW_NOT_TEMP", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "SHOW CREATE TABLE" @@ -368,7 +368,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") }, - errorClass = "EXPECT_PERMANENT_VIEW_NOT_TEMP", + condition = "EXPECT_PERMANENT_VIEW_NOT_TEMP", parameters = Map( "viewName" -> s"`$viewName`", "operation" -> "ANALYZE TABLE" @@ -383,18 +383,19 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") }, - errorClass = "UNSUPPORTED_FEATURE.ANALYZE_UNCACHED_TEMP_VIEW", + condition = "UNSUPPORTED_FEATURE.ANALYZE_UNCACHED_TEMP_VIEW", parameters = Map("viewName" -> s"`$viewName`") ) } } - private def assertAnalysisErrorClass(query: String, - errorClass: String, + private def assertAnalysisErrorCondition( + query: String, + condition: String, parameters: Map[String, String], context: ExpectedContext): Unit = { val e = intercept[AnalysisException](sql(query)) - checkError(e, errorClass = errorClass, parameters = parameters, context = context) + checkError(e, condition = condition, parameters = parameters, context = context) } test("error handling: insert/load table commands against a view") { @@ -405,7 +406,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"INSERT INTO TABLE $viewName SELECT 1") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`testview`", "operation" -> "INSERT" @@ -420,7 +421,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`testview`", "operation" -> "LOAD DATA"), @@ -488,7 +489,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { test("error handling: fail if the temp view sql itself is invalid") { // A database that does not exist - assertAnalysisErrorClass( + assertAnalysisErrorCondition( "CREATE OR REPLACE TEMPORARY VIEW myabcdview AS SELECT * FROM db_not_exist234.jt", "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`db_not_exist234`.`jt`"), @@ -513,7 +514,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[ParseException] { sql(sqlText) }, 
- errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> "TBLPROPERTIES can't coexist with CREATE TEMPORARY VIEW"), context = ExpectedContext(sqlText, 0, 77)) } @@ -867,7 +868,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql("CREATE VIEW testView2(x, y, z) AS SELECT * FROM tab1") }, - errorClass = "CREATE_VIEW_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "CREATE_VIEW_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`testView2`", "viewColumns" -> "`x`, `y`, `z`", @@ -884,7 +885,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { .write.mode(SaveMode.Overwrite).saveAsTable("tab1") checkError( exception = intercept[AnalysisException](sql("SELECT * FROM testView")), - errorClass = "INCOMPATIBLE_VIEW_SCHEMA_CHANGE", + condition = "INCOMPATIBLE_VIEW_SCHEMA_CHANGE", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`testview`", "actualCols" -> "[]", @@ -914,7 +915,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { df2.write.format("json").mode(SaveMode.Overwrite).saveAsTable("tab1") checkError( exception = intercept[AnalysisException](sql("SELECT * FROM testView")), - errorClass = "CANNOT_UP_CAST_DATATYPE", + condition = "CANNOT_UP_CAST_DATATYPE", parameters = Map( "expression" -> s"$SESSION_CATALOG_NAME.default.tab1.id", "sourceType" -> "\"DOUBLE\"", @@ -930,7 +931,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { df3.write.format("json").mode(SaveMode.Overwrite).saveAsTable("tab1") checkError( exception = intercept[AnalysisException](sql("SELECT * FROM testView")), - errorClass = "CANNOT_UP_CAST_DATATYPE", + condition = "CANNOT_UP_CAST_DATATYPE", parameters = Map( "expression" -> s"$SESSION_CATALOG_NAME.default.tab1.id1", "sourceType" -> "\"ARRAY\"", @@ -956,7 +957,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql("ALTER VIEW view1 AS SELECT * FROM view2") }, - errorClass = "RECURSIVE_VIEW", + condition = "RECURSIVE_VIEW", parameters = Map( "viewIdent" -> s"`$SESSION_CATALOG_NAME`.`default`.`view1`", "newPath" -> (s"`$SESSION_CATALOG_NAME`.`default`.`view1` -> " + @@ -970,7 +971,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql("ALTER VIEW view1 AS SELECT * FROM view3 JOIN view2") }, - errorClass = "RECURSIVE_VIEW", + condition = "RECURSIVE_VIEW", parameters = Map( "viewIdent" -> s"`$SESSION_CATALOG_NAME`.`default`.`view1`", "newPath" -> (s"`$SESSION_CATALOG_NAME`.`default`.`view1` -> " + @@ -985,7 +986,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql("CREATE OR REPLACE VIEW view1 AS SELECT * FROM view2") }, - errorClass = "RECURSIVE_VIEW", + condition = "RECURSIVE_VIEW", parameters = Map( "viewIdent" -> s"`$SESSION_CATALOG_NAME`.`default`.`view1`", "newPath" -> (s"`$SESSION_CATALOG_NAME`.`default`.`view1` -> " + @@ -999,7 +1000,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql("ALTER VIEW view1 AS SELECT * FROM jt WHERE EXISTS (SELECT 1 FROM view2)") }, - errorClass = "RECURSIVE_VIEW", + condition = "RECURSIVE_VIEW", parameters = Map( "viewIdent" -> s"`$SESSION_CATALOG_NAME`.`default`.`view1`", "newPath" -> 
(s"`$SESSION_CATALOG_NAME`.`default`.`view1` -> " + @@ -1071,7 +1072,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[SparkException] { sql("SELECT * FROM v1").collect() }, - errorClass = "FAILED_READ_FILE.FILE_NOT_EXIST", + condition = "FAILED_READ_FILE.FILE_NOT_EXIST", parameters = Map("path" -> ".*") ) } @@ -1091,7 +1092,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[SparkException] { sql("SELECT * FROM v1").collect() }, - errorClass = "FAILED_READ_FILE.FILE_NOT_EXIST", + condition = "FAILED_READ_FILE.FILE_NOT_EXIST", parameters = Map("path" -> ".*") ) } @@ -1157,7 +1158,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { sql("SELECT * FROM v1") } checkError(e, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map( "objectName" -> "`C1`", @@ -1178,7 +1179,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { sql("SELECT * FROM v3") } checkError(e, - errorClass = "MISSING_AGGREGATION", + condition = "MISSING_AGGREGATION", parameters = Map( "expression" -> "\"c1\"", "expressionAnyValue" -> "\"any_value(c1)\"")) @@ -1188,7 +1189,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { sql("SELECT * FROM v4") } checkError(e, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map( "objectName" -> "`a`", @@ -1206,7 +1207,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[SparkArithmeticException] { sql("SELECT * FROM v5").collect() }, - errorClass = "DIVIDE_BY_ZERO", + condition = "DIVIDE_BY_ZERO", parameters = Map("config" -> "\"spark.sql.ansi.enabled\""), context = ExpectedContext( objectType = "VIEW", @@ -1225,7 +1226,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { exception = intercept[SparkArithmeticException] { sql("SELECT * FROM v1").collect() }, - errorClass = "DIVIDE_BY_ZERO", + condition = "DIVIDE_BY_ZERO", parameters = Map("config" -> "\"spark.sql.ansi.enabled\""), context = ExpectedContext( objectType = "VIEW", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala index aa6295fa8815f..0faace9227dd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala @@ -182,7 +182,7 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { createView("v1", s"SELECT * FROM $viewName2", replace = true) }, - errorClass = "RECURSIVE_VIEW", + condition = "RECURSIVE_VIEW", parameters = Map( "viewIdent" -> tableIdentifier("v1").quotedString, "newPath" -> (s"${tableIdentifier("v1").quotedString} " + @@ -203,7 +203,7 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER VIEW $viewName1 AS SELECT * FROM $viewName2") }, - errorClass = "RECURSIVE_VIEW", + condition = "RECURSIVE_VIEW", parameters = Map( "viewIdent" -> tableIdentifier("v1").quotedString, "newPath" -> (s"${tableIdentifier("v1").quotedString} " + @@ -227,7 +227,7 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(s"SELECT * FROM ${viewNames.last}") }, 
- errorClass = "VIEW_EXCEED_MAX_NESTED_DEPTH", + condition = "VIEW_EXCEED_MAX_NESTED_DEPTH", parameters = Map( "viewName" -> tableIdentifier("view0").quotedString, "maxNestedDepth" -> "10"), @@ -363,7 +363,7 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { sql("CREATE TABLE t(s STRUCT) USING json") checkError( exception = intercept[AnalysisException](spark.table(viewName)), - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`i`", "fields" -> "`j`"), context = ExpectedContext( fragment = "s.i", @@ -399,7 +399,7 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { } else { checkErrorMatchPVals( exception = intercept[AnalysisException](spark.table(viewName).collect()), - errorClass = "INCOMPATIBLE_VIEW_SCHEMA_CHANGE", + condition = "INCOMPATIBLE_VIEW_SCHEMA_CHANGE", parameters = Map( "viewName" -> ".*test[v|V]iew.*", "actualCols" -> "\\[COL,col,col\\]", @@ -436,7 +436,7 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException]( sql(s"SELECT * FROM $viewName VERSION AS OF 1").collect() ), - errorClass = "UNSUPPORTED_FEATURE.TIME_TRAVEL", + condition = "UNSUPPORTED_FEATURE.TIME_TRAVEL", parameters = Map("relationId" -> ".*test[v|V]iew.*") ) @@ -444,7 +444,7 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException]( sql(s"SELECT * FROM $viewName TIMESTAMP AS OF '2000-10-10'").collect() ), - errorClass = "UNSUPPORTED_FEATURE.TIME_TRAVEL", + condition = "UNSUPPORTED_FEATURE.TIME_TRAVEL", parameters = Map("relationId" -> ".*test[v|V]iew.*") ) } @@ -489,7 +489,7 @@ abstract class TempViewTestSuite extends SQLViewTestSuite { exception = intercept[AnalysisException] { sql(s"SHOW CREATE TABLE ${formattedViewName(viewName)}") }, - errorClass = "EXPECT_PERMANENT_VIEW_NOT_TEMP", + condition = "EXPECT_PERMANENT_VIEW_NOT_TEMP", parameters = Map( "viewName" -> toSQLId(tableIdentifier(viewName).nameParts), "operation" -> "SHOW CREATE TABLE"), @@ -577,7 +577,7 @@ class PersistedViewTestSuite extends SQLViewTestSuite with SharedSparkSession { exception = intercept[AnalysisException] { sql("CREATE VIEW v AS SELECT count(*) FROM VALUES (1), (2), (3) t(a)") }, - errorClass = "CREATE_PERMANENT_VIEW_WITHOUT_ALIAS", + condition = "CREATE_PERMANENT_VIEW_WITHOUT_ALIAS", parameters = Map("name" -> tableIdentifier("v").quotedString, "attr" -> "\"count(1)\"") ) sql("CREATE VIEW v AS SELECT count(*) AS cnt FROM VALUES (1), (2), (3) t(a)") @@ -591,7 +591,7 @@ class PersistedViewTestSuite extends SQLViewTestSuite with SharedSparkSession { exception = intercept[AnalysisException] { sql("CREATE VIEW v AS SELECT * FROM (SELECT a + b FROM VALUES (1, 2) t(a, b))") }, - errorClass = "CREATE_PERMANENT_VIEW_WITHOUT_ALIAS", + condition = "CREATE_PERMANENT_VIEW_WITHOUT_ALIAS", parameters = Map("name" -> tableIdentifier("v").quotedString, "attr" -> "\"(a + b)\"") ) sql("CREATE VIEW v AS SELECT * FROM (SELECT a + b AS col FROM VALUES (1, 2) t(a, b))") @@ -606,7 +606,7 @@ class PersistedViewTestSuite extends SQLViewTestSuite with SharedSparkSession { exception = intercept[AnalysisException] { sql("ALTER VIEW v AS SELECT count(*) FROM VALUES (1), (2), (3) t(a)") }, - errorClass = "CREATE_PERMANENT_VIEW_WITHOUT_ALIAS", + condition = "CREATE_PERMANENT_VIEW_WITHOUT_ALIAS", parameters = Map("name" -> tableIdentifier("v").quotedString, "attr" -> "\"count(1)\"") ) } @@ -639,7 +639,7 @@ class PersistedViewTestSuite extends SQLViewTestSuite with 
SharedSparkSession { val unquotedViewName = tableIdentifier("test_view").unquotedString checkError( exception = e, - errorClass = "INCOMPATIBLE_VIEW_SCHEMA_CHANGE", + condition = "INCOMPATIBLE_VIEW_SCHEMA_CHANGE", parameters = Map( "viewName" -> tableIdentifier("test_view").quotedString, "suggestion" -> s"CREATE OR REPLACE VIEW $unquotedViewName AS SELECT * FROM t", @@ -665,7 +665,7 @@ class PersistedViewTestSuite extends SQLViewTestSuite with SharedSparkSession { exception = intercept[AnalysisException] { sql("ALTER VIEW v1 AS SELECT * FROM v2") }, - errorClass = "INVALID_TEMP_OBJ_REFERENCE", + condition = "INVALID_TEMP_OBJ_REFERENCE", parameters = Map( "obj" -> "VIEW", "objName" -> tableIdentifier("v1").quotedString, @@ -679,7 +679,7 @@ class PersistedViewTestSuite extends SQLViewTestSuite with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"ALTER VIEW v1 AS SELECT $tempFunctionName(id) from t") }, - errorClass = "INVALID_TEMP_OBJ_REFERENCE", + condition = "INVALID_TEMP_OBJ_REFERENCE", parameters = Map( "obj" -> "VIEW", "objName" -> tableIdentifier("v1").quotedString, @@ -724,7 +724,7 @@ class PersistedViewTestSuite extends SQLViewTestSuite with SharedSparkSession { exception = intercept[AnalysisException] { sql("SELECT * FROM v") }, - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'DROP'", "hint" -> ""), context = ExpectedContext( objectType = "VIEW", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala index 966f4e747122a..8dc07e2df99fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala @@ -112,7 +112,7 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkException] { planner.plan(deduplicate) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map( "message" -> ("Deduplicate operator for non streaming data source should have been " + "replaced by aggregate in the optimizer"))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala index 52378f7370930..dbb8e9697089e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala @@ -54,7 +54,7 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with |FROM v""".stripMargin checkError( exception = intercept[ParseException](sql(sqlText)), - errorClass = "UNSUPPORTED_FEATURE.TRANSFORM_NON_HIVE", + condition = "UNSUPPORTED_FEATURE.TRANSFORM_NON_HIVE", parameters = Map.empty, context = ExpectedContext(sqlText, 0, 185)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index f60df77b7e9bd..ab949c5a21e44 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -26,10 +26,11 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, Un import org.apache.spark.sql.catalyst.expressions.{Ascending, 
AttributeReference, Concat, GreaterThan, Literal, NullsFirst, SortOrder, UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, WindowSpecReference} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.TreePattern.{FILTER, LOCAL_RELATION, PROJECT, UNRESOLVED_RELATION} import org.apache.spark.sql.connector.catalog.TableCatalog import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, RefreshResource} -import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StringType import org.apache.spark.util.ArrayImplicits._ @@ -83,12 +84,12 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { checkError( exception = parseException("SET k=`v` /*"), - errorClass = "UNCLOSED_BRACKETED_COMMENT", + condition = "UNCLOSED_BRACKETED_COMMENT", parameters = Map.empty) checkError( exception = parseException("SET `k`=`v` /*"), - errorClass = "UNCLOSED_BRACKETED_COMMENT", + condition = "UNCLOSED_BRACKETED_COMMENT", parameters = Map.empty) } @@ -120,7 +121,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql1 = "SET spark.sql.key value" checkError( exception = parseException(sql1), - errorClass = "INVALID_SET_SYNTAX", + condition = "INVALID_SET_SYNTAX", parameters = Map.empty, context = ExpectedContext( fragment = sql1, @@ -130,7 +131,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql2 = "SET spark.sql.key 'value'" checkError( exception = parseException(sql2), - errorClass = "INVALID_SET_SYNTAX", + condition = "INVALID_SET_SYNTAX", parameters = Map.empty, context = ExpectedContext( fragment = sql2, @@ -140,7 +141,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql3 = "SET spark.sql.key \"value\" " checkError( exception = parseException(sql3), - errorClass = "INVALID_SET_SYNTAX", + condition = "INVALID_SET_SYNTAX", parameters = Map.empty, context = ExpectedContext( fragment = "SET spark.sql.key \"value\"", @@ -150,7 +151,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql4 = "SET spark.sql.key value1 value2" checkError( exception = parseException(sql4), - errorClass = "INVALID_SET_SYNTAX", + condition = "INVALID_SET_SYNTAX", parameters = Map.empty, context = ExpectedContext( fragment = sql4, @@ -160,7 +161,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql5 = "SET spark. sql.key=value" checkError( exception = parseException(sql5), - errorClass = "INVALID_SET_SYNTAX", + condition = "INVALID_SET_SYNTAX", parameters = Map.empty, context = ExpectedContext( fragment = sql5, @@ -170,7 +171,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql6 = "SET spark :sql:key=value" checkError( exception = parseException(sql6), - errorClass = "INVALID_SET_SYNTAX", + condition = "INVALID_SET_SYNTAX", parameters = Map.empty, context = ExpectedContext( fragment = sql6, @@ -180,7 +181,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql7 = "SET spark . 
sql.key=value" checkError( exception = parseException(sql7), - errorClass = "INVALID_SET_SYNTAX", + condition = "INVALID_SET_SYNTAX", parameters = Map.empty, context = ExpectedContext( fragment = sql7, @@ -190,7 +191,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql8 = "SET spark.sql. key=value" checkError( exception = parseException(sql8), - errorClass = "INVALID_SET_SYNTAX", + condition = "INVALID_SET_SYNTAX", parameters = Map.empty, context = ExpectedContext( fragment = sql8, @@ -200,7 +201,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql9 = "SET spark.sql :key=value" checkError( exception = parseException(sql9), - errorClass = "INVALID_SET_SYNTAX", + condition = "INVALID_SET_SYNTAX", parameters = Map.empty, context = ExpectedContext( fragment = sql9, @@ -210,7 +211,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql10 = "SET spark.sql . key=value" checkError( exception = parseException(sql10), - errorClass = "INVALID_SET_SYNTAX", + condition = "INVALID_SET_SYNTAX", parameters = Map.empty, context = ExpectedContext( fragment = sql10, @@ -220,7 +221,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql11 = "SET =" checkError( exception = parseException(sql11), - errorClass = "INVALID_SET_SYNTAX", + condition = "INVALID_SET_SYNTAX", parameters = Map.empty, context = ExpectedContext( fragment = sql11, @@ -230,7 +231,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql12 = "SET =value" checkError( exception = parseException(sql12), - errorClass = "INVALID_SET_SYNTAX", + condition = "INVALID_SET_SYNTAX", parameters = Map.empty, context = ExpectedContext( fragment = sql12, @@ -251,7 +252,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql1 = "RESET spark.sql.key1 key2" checkError( exception = parseException(sql1), - errorClass = "_LEGACY_ERROR_TEMP_0043", + condition = "_LEGACY_ERROR_TEMP_0043", parameters = Map.empty, context = ExpectedContext( fragment = sql1, @@ -261,7 +262,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql2 = "RESET spark. 
sql.key1 key2" checkError( exception = parseException(sql2), - errorClass = "_LEGACY_ERROR_TEMP_0043", + condition = "_LEGACY_ERROR_TEMP_0043", parameters = Map.empty, context = ExpectedContext( fragment = sql2, @@ -271,7 +272,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql3 = "RESET spark.sql.key1 key2 key3" checkError( exception = parseException(sql3), - errorClass = "_LEGACY_ERROR_TEMP_0043", + condition = "_LEGACY_ERROR_TEMP_0043", parameters = Map.empty, context = ExpectedContext( fragment = sql3, @@ -281,7 +282,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql4 = "RESET spark: sql:key" checkError( exception = parseException(sql4), - errorClass = "_LEGACY_ERROR_TEMP_0043", + condition = "_LEGACY_ERROR_TEMP_0043", parameters = Map.empty, context = ExpectedContext( fragment = sql4, @@ -291,7 +292,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql5 = "RESET spark .sql.key" checkError( exception = parseException(sql5), - errorClass = "_LEGACY_ERROR_TEMP_0043", + condition = "_LEGACY_ERROR_TEMP_0043", parameters = Map.empty, context = ExpectedContext( fragment = sql5, @@ -301,7 +302,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql6 = "RESET spark : sql:key" checkError( exception = parseException(sql6), - errorClass = "_LEGACY_ERROR_TEMP_0043", + condition = "_LEGACY_ERROR_TEMP_0043", parameters = Map.empty, context = ExpectedContext( fragment = sql6, @@ -311,7 +312,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql7 = "RESET spark.sql: key" checkError( exception = parseException(sql7), - errorClass = "_LEGACY_ERROR_TEMP_0043", + condition = "_LEGACY_ERROR_TEMP_0043", parameters = Map.empty, context = ExpectedContext( fragment = sql7, @@ -321,7 +322,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql8 = "RESET spark.sql .key" checkError( exception = parseException(sql8), - errorClass = "_LEGACY_ERROR_TEMP_0043", + condition = "_LEGACY_ERROR_TEMP_0043", parameters = Map.empty, context = ExpectedContext( fragment = sql8, @@ -331,7 +332,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql9 = "RESET spark.sql : key" checkError( exception = parseException(sql9), - errorClass = "_LEGACY_ERROR_TEMP_0043", + condition = "_LEGACY_ERROR_TEMP_0043", parameters = Map.empty, context = ExpectedContext( fragment = sql9, @@ -354,7 +355,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql1 = "SET a=1; SELECT 1" checkError( exception = parseException(sql1), - errorClass = "INVALID_SET_SYNTAX", + condition = "INVALID_SET_SYNTAX", parameters = Map.empty, context = ExpectedContext( fragment = sql1, @@ -364,7 +365,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql2 = "SET a=1;2;;" checkError( exception = parseException(sql2), - errorClass = "INVALID_SET_SYNTAX", + condition = "INVALID_SET_SYNTAX", parameters = Map.empty, context = ExpectedContext( fragment = "SET a=1;2", @@ -374,7 +375,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql3 = "SET a b=`1;;`" checkError( exception = parseException(sql3), - errorClass = "INVALID_PROPERTY_KEY", + condition = "INVALID_PROPERTY_KEY", parameters = Map("key" -> "\"a b\"", "value" -> "\"1;;\""), context = ExpectedContext( fragment = sql3, @@ -384,7 +385,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { 
val sql4 = "SET `a`=1;2;;" checkError( exception = parseException(sql4), - errorClass = "INVALID_PROPERTY_VALUE", + condition = "INVALID_PROPERTY_VALUE", parameters = Map("value" -> "\"1;2;;\"", "key" -> "\"a\""), context = ExpectedContext( fragment = "SET `a`=1;2", @@ -407,7 +408,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql1 = "REFRESH a b" checkError( exception = parseException(sql1), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> errMsg1), context = ExpectedContext( fragment = sql1, @@ -417,7 +418,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql2 = "REFRESH a\tb" checkError( exception = parseException(sql2), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> errMsg1), context = ExpectedContext( fragment = sql2, @@ -427,7 +428,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql3 = "REFRESH a\nb" checkError( exception = parseException(sql3), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> errMsg1), context = ExpectedContext( fragment = sql3, @@ -437,7 +438,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql4 = "REFRESH a\rb" checkError( exception = parseException(sql4), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> errMsg1), context = ExpectedContext( fragment = sql4, @@ -447,7 +448,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql5 = "REFRESH a\r\nb" checkError( exception = parseException(sql5), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> errMsg1), context = ExpectedContext( fragment = sql5, @@ -457,7 +458,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql6 = "REFRESH @ $a$" checkError( exception = parseException(sql6), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> errMsg1), context = ExpectedContext( fragment = sql6, @@ -468,7 +469,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql7 = "REFRESH " checkError( exception = parseException(sql7), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> errMsg2), context = ExpectedContext( fragment = "REFRESH", @@ -478,7 +479,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { val sql8 = "REFRESH" checkError( exception = parseException(sql8), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> errMsg2), context = ExpectedContext( fragment = sql8, @@ -741,7 +742,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { |FROM v""".stripMargin checkError( exception = parseException(sql1), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> errMsg), context = ExpectedContext( fragment = sql1, @@ -763,7 +764,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { |FROM v""".stripMargin checkError( exception = parseException(sql2), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> errMsg), context = ExpectedContext( fragment = sql2, @@ -780,7 +781,7 
@@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { s"CREATE TABLE target LIKE source TBLPROPERTIES (${TableCatalog.PROP_OWNER}='howdy')" checkError( exception = parseException(sql1), - errorClass = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", + condition = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", parameters = Map("property" -> TableCatalog.PROP_OWNER, "msg" -> "it will be set to the current user"), context = ExpectedContext( @@ -792,7 +793,7 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { s"CREATE TABLE target LIKE source TBLPROPERTIES (${TableCatalog.PROP_PROVIDER}='howdy')" checkError( exception = parseException(sql2), - errorClass = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", + condition = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", parameters = Map("property" -> TableCatalog.PROP_PROVIDER, "msg" -> "please use the USING clause to specify it"), context = ExpectedContext( @@ -880,4 +881,30 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { parser.parsePlan("SELECT\u30001") // Unicode ideographic space } // scalastyle:on + + test("Operator pipe SQL syntax") { + withSQLConf(SQLConf.OPERATOR_PIPE_SYNTAX_ENABLED.key -> "true") { + // Basic selection. + // Here we check that every parsed plan contains a projection and a source relation or + // inline table. + def checkPipeSelect(query: String): Unit = { + val plan: LogicalPlan = parser.parsePlan(query) + assert(plan.containsPattern(PROJECT)) + assert(plan.containsAnyPattern(UNRESOLVED_RELATION, LOCAL_RELATION)) + } + checkPipeSelect("TABLE t |> SELECT 1 AS X") + checkPipeSelect("TABLE t |> SELECT 1 AS X, 2 AS Y |> SELECT X + Y AS Z") + checkPipeSelect("VALUES (0), (1) tab(col) |> SELECT col * 2 AS result") + // Basic WHERE operators. 
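+ // As with checkPipeSelect above, each call below parses a |> WHERE pipeline and
+ // asserts that the resulting logical plan contains a Filter node over an unresolved
+ // relation or inline (local) relation.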
+ def checkPipeWhere(query: String): Unit = { + val plan: LogicalPlan = parser.parsePlan(query) + assert(plan.containsPattern(FILTER)) + assert(plan.containsAnyPattern(UNRESOLVED_RELATION, LOCAL_RELATION)) + } + checkPipeWhere("TABLE t |> WHERE X = 1") + checkPipeWhere("TABLE t |> SELECT X, LENGTH(Y) AS Z |> WHERE X + LENGTH(Y) < 4") + checkPipeWhere("TABLE t |> WHERE X = 1 AND Y = 2 |> WHERE X + Y = 3") + checkPipeWhere("VALUES (0), (1) tab(col) |> WHERE col < 1") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 4d2d465828924..a3cfdc5a240a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -878,7 +878,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession exception = intercept[SparkException] { sql(query).collect() }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> expectedErrMsg), matchPVals = true) } @@ -903,7 +903,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession exception = intercept[SparkException] { sql(query).collect() }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> expectedErrMsg), matchPVals = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 938a96a86b015..75f016d050de9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -1082,7 +1082,7 @@ class AdaptiveQueryExecSuite val doExecute = PrivateMethod[Unit](Symbol("doExecute")) c.invokePrivate(doExecute()) }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "operating on canonicalized plan")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index a95bda9bf71df..cb97066098f20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -186,7 +186,7 @@ class ColumnTypeSuite extends SparkFunSuite { exception = intercept[SparkUnsupportedOperationException] { ColumnType(invalidType) }, - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map("typeName" -> "\"INVALID TYPE NAME\"") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index ad755bf22ab09..0ba55382cd9a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -150,7 +150,7 @@ class InMemoryColumnarQuerySuite extends QueryTest spark.catalog.cacheTable("sizeTst") assert( spark.table("sizeTst").queryExecution.analyzed.stats.sizeInBytes > - 
spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + sqlConf.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 885286843a143..88ff51d0ff4cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -27,9 +27,9 @@ class PartitionBatchPruningSuite extends SharedSparkSession with AdaptiveSparkPl import testImplicits._ - private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE) + private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE.key) private lazy val originalInMemoryPartitionPruning = - spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING) + spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING.key) private val testArrayData = (1 to 100).map { key => Tuple1(Array.fill(key)(key)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala index 488b4d31bd923..cd099a2a94813 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala @@ -591,7 +591,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { if (policy == StoreAssignmentPolicy.ANSI) { checkError( exception = e, - errorClass = "DATATYPE_MISMATCH.INVALID_ROW_LEVEL_OPERATION_ASSIGNMENTS", + condition = "DATATYPE_MISMATCH.INVALID_ROW_LEVEL_OPERATION_ASSIGNMENTS", parameters = Map( "sqlExpr" -> "\"s.n_i = 1\", \"s.n_s = NULL\", \"s.n_i = -1\"", "errors" -> "\n- Multiple assignments for 's.n_i': 1, -1") @@ -599,7 +599,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { } else { checkError( exception = e, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`s`.`n_s`", @@ -701,7 +701,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { } checkError( exception = e, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", parameters = Map("tableName" -> "``", "colName" -> "`s`.`n_s`.`dn_l`") ) @@ -847,7 +847,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { } checkError( exception = e, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", parameters = Map("tableName" -> "``", "colName" -> "`s`.`n_s`.`dn_l`") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala index 599f3e994ef8a..3c8ce44f8167b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala @@ -478,7 +478,7 @@ class AlignUpdateAssignmentsSuite extends AlignAssignmentsSuiteBase { if (policy == 
StoreAssignmentPolicy.ANSI) { checkError( exception = e, - errorClass = "DATATYPE_MISMATCH.INVALID_ROW_LEVEL_OPERATION_ASSIGNMENTS", + condition = "DATATYPE_MISMATCH.INVALID_ROW_LEVEL_OPERATION_ASSIGNMENTS", parameters = Map( "sqlExpr" -> "\"s.n_i = 1\", \"s.n_s = NULL\", \"s.n_i = -1\"", "errors" -> "\n- Multiple assignments for 's.n_i': 1, -1") @@ -486,7 +486,7 @@ class AlignUpdateAssignmentsSuite extends AlignAssignmentsSuiteBase { } else { checkError( exception = e, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "``", "colName" -> "`s`.`n_s`", @@ -538,7 +538,7 @@ class AlignUpdateAssignmentsSuite extends AlignAssignmentsSuiteBase { } checkError( exception = e, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", parameters = Map("tableName" -> "``", "colName" -> "`s`.`n_s`.`dn_l`") ) @@ -591,7 +591,7 @@ class AlignUpdateAssignmentsSuite extends AlignAssignmentsSuiteBase { } checkError( exception = e, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", parameters = Map("tableName" -> "``", "colName" -> "`s`.`n_s`.`dn_l`") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationSuiteBase.scala index 2d4277e5499e8..64491f9ad9741 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationSuiteBase.scala @@ -51,7 +51,7 @@ trait AlterNamespaceSetLocationSuiteBase extends QueryTest with DDLCommandTestUt exception = intercept[SparkIllegalArgumentException] { sql(sqlText) }, - errorClass = "INVALID_EMPTY_LOCATION", + condition = "INVALID_EMPTY_LOCATION", parameters = Map("location" -> "")) } } @@ -66,7 +66,7 @@ trait AlterNamespaceSetLocationSuiteBase extends QueryTest with DDLCommandTestUt } checkError( exception = e, - errorClass = "INVALID_LOCATION", + condition = "INVALID_LOCATION", parameters = Map("location" -> "file:tmp")) } } @@ -77,7 +77,7 @@ trait AlterNamespaceSetLocationSuiteBase extends QueryTest with DDLCommandTestUt sql(s"ALTER DATABASE $catalog.$ns SET LOCATION 'loc'") } checkError(e, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> s"`$catalog`.`$ns`")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesParserSuite.scala index 9d70ceeef578e..70abfe8af5266 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesParserSuite.scala @@ -43,7 +43,7 @@ class AlterNamespaceSetPropertiesParserSuite extends AnalysisTest { val sql = "ALTER NAMESPACE my_db SET PROPERTIES('key_without_value', 'key_with_value'='x')" checkError( exception = parseException(parsePlan)(sql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> "Values must be specified for key(s): [key_without_value]"), context = 
ExpectedContext( fragment = sql, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesSuiteBase.scala index d2f2d75d86ce9..3b0ac1d408234 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetPropertiesSuiteBase.scala @@ -49,7 +49,7 @@ trait AlterNamespaceSetPropertiesSuiteBase extends QueryTest with DDLCommandTest sql(s"ALTER DATABASE $catalog.$ns SET PROPERTIES ('d'='d')") } checkError(e, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> s"`$catalog`.`$ns`")) } @@ -88,7 +88,7 @@ trait AlterNamespaceSetPropertiesSuiteBase extends QueryTest with DDLCommandTest exception = intercept[ParseException] { sql(sqlText) }, - errorClass = "UNSUPPORTED_FEATURE.SET_NAMESPACE_PROPERTY", + condition = "UNSUPPORTED_FEATURE.SET_NAMESPACE_PROPERTY", parameters = Map("property" -> key, "msg" -> ".*"), sqlState = None, context = ExpectedContext( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceUnsetPropertiesParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceUnsetPropertiesParserSuite.scala index 72d307c816664..11e0f6c29bef5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceUnsetPropertiesParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceUnsetPropertiesParserSuite.scala @@ -54,7 +54,7 @@ class AlterNamespaceUnsetPropertiesParserSuite extends AnalysisTest with SharedS val sql = "ALTER NAMESPACE my_db UNSET PROPERTIES('key_without_value', 'key_with_value'='x')" checkError( exception = parseException(sql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> "Values should not be specified for key(s): [key_with_value]"), context = ExpectedContext( fragment = sql, @@ -68,7 +68,7 @@ class AlterNamespaceUnsetPropertiesParserSuite extends AnalysisTest with SharedS val sql = s"ALTER $nsToken a.b.c UNSET $propToken IF EXISTS ('a', 'b', 'c')" checkError( exception = parseException(sql), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'IF'", "hint" -> ": missing '('") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceUnsetPropertiesSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceUnsetPropertiesSuiteBase.scala index c00f3f99f41f9..42550ef844361 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceUnsetPropertiesSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceUnsetPropertiesSuiteBase.scala @@ -56,7 +56,7 @@ trait AlterNamespaceUnsetPropertiesSuiteBase extends QueryTest with DDLCommandTe sql(s"ALTER NAMESPACE $catalog.$ns UNSET PROPERTIES ('d')") } checkError(e, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> s"`$catalog`.`$ns`")) } @@ -90,7 +90,7 @@ trait AlterNamespaceUnsetPropertiesSuiteBase extends QueryTest with DDLCommandTe exception = intercept[ParseException] { sql(sqlText) }, - errorClass = "UNSUPPORTED_FEATURE.SET_NAMESPACE_PROPERTY", + 
condition = "UNSUPPORTED_FEATURE.SET_NAMESPACE_PROPERTY", parameters = Map("property" -> key, "msg" -> ".*"), sqlState = None, context = ExpectedContext( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala index 3feb4f13c73f2..cb25942822f46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala @@ -234,9 +234,8 @@ trait AlterTableAddPartitionSuiteBase extends QueryTest with DDLCommandTestUtils exception = intercept[SparkNumberFormatException] { sql(s"ALTER TABLE $t ADD PARTITION (p='aaa')") }, - errorClass = "CAST_INVALID_INPUT", + condition = "CAST_INVALID_INPUT", parameters = Map( - "ansiConfig" -> "\"spark.sql.ansi.enabled\"", "expression" -> "'aaa'", "sourceType" -> "\"STRING\"", "targetType" -> "\"INT\""), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableClusterBySuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableClusterBySuiteBase.scala index 73a80cd910698..c0fd0a67d06aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableClusterBySuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableClusterBySuiteBase.scala @@ -83,7 +83,7 @@ trait AlterTableClusterBySuiteBase extends QueryTest with DDLCommandTestUtils { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $tbl CLUSTER BY (unknown)") }, - errorClass = "_LEGACY_ERROR_TEMP_3060", + condition = "_LEGACY_ERROR_TEMP_3060", parameters = Map("i" -> "unknown", "schema" -> """root diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionParserSuite.scala index 9b4b026480a16..2aa77dac711d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionParserSuite.scala @@ -90,7 +90,7 @@ class AlterTableDropPartitionParserSuite extends AnalysisTest with SharedSparkSe val sql = "ALTER VIEW table_name DROP PARTITION (p=1)" checkError( exception = parseException(parsePlan)(sql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER VIEW ... 
DROP PARTITION"), context = ExpectedContext( fragment = sql, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionSuiteBase.scala index ef9ae47847405..279042f675cd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionSuiteBase.scala @@ -146,7 +146,7 @@ trait AlterTableDropPartitionSuiteBase extends QueryTest with DDLCommandTestUtil "`test_catalog`.`ns`.`tbl`" } checkError(e, - errorClass = "PARTITIONS_NOT_FOUND", + condition = "PARTITIONS_NOT_FOUND", parameters = Map("partitionList" -> "PARTITION (`id` = 2)", "tableName" -> expectedTableName)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala index 936b1a3dfdb20..babf490729564 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRecoverPartitionsParserSuite.scala @@ -28,7 +28,7 @@ class AlterTableRecoverPartitionsParserSuite extends AnalysisTest with SharedSpa val sql = "ALTER TABLE RECOVER PARTITIONS" checkError( exception = parseException(parsePlan)(sql), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'PARTITIONS'", "hint" -> "")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameColumnParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameColumnParserSuite.scala index 62ee8aa57a760..1df4800fa7542 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameColumnParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameColumnParserSuite.scala @@ -37,12 +37,12 @@ class AlterTableRenameColumnParserSuite extends AnalysisTest with SharedSparkSes checkError( exception = parseException(parsePlan)( "ALTER TABLE t RENAME COLUMN test-col TO test"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-col")) checkError( exception = parseException(parsePlan)( "ALTER TABLE t RENAME COLUMN test TO test-col"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", parameters = Map("ident" -> "test-col")) } @@ -50,7 +50,7 @@ class AlterTableRenameColumnParserSuite extends AnalysisTest with SharedSparkSes checkError( exception = parseException(parsePlan)( "ALTER TABLE t RENAME COLUMN point.x to point.y"), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'.'", "hint" -> "")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameParserSuite.scala index 098750c929ecd..83d590e2bb35c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameParserSuite.scala @@ -44,13 +44,13 @@ class AlterTableRenameParserSuite extends 
AnalysisTest { val sql1 = "ALTER TABLE RENAME TO x.y.z" checkError( exception = parseException(parsePlan)(sql1), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'TO'", "hint" -> "")) val sql2 = "ALTER TABLE _ RENAME TO .z" checkError( exception = parseException(parsePlan)(sql2), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'.'", "hint" -> "")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionSuiteBase.scala index d91085956e330..905e6cfb9caaa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionSuiteBase.scala @@ -81,7 +81,7 @@ trait AlterTableRenamePartitionSuiteBase extends QueryTest with DDLCommandTestUt sql(s"ALTER TABLE $t PARTITION (id = 3) RENAME TO PARTITION (id = 2)") } checkError(e, - errorClass = "PARTITIONS_NOT_FOUND", + condition = "PARTITIONS_NOT_FOUND", parameters = Map("partitionList" -> "PARTITION (`id` = 3)", "tableName" -> parsed)) } @@ -103,7 +103,7 @@ trait AlterTableRenamePartitionSuiteBase extends QueryTest with DDLCommandTestUt sql(s"ALTER TABLE $t PARTITION (id = 1) RENAME TO PARTITION (id = 2)") } checkError(e, - errorClass = "PARTITIONS_ALREADY_EXIST", + condition = "PARTITIONS_ALREADY_EXIST", parameters = Map("partitionList" -> "PARTITION (`id` = 2)", "tableName" -> parsed)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableSetSerdeParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableSetSerdeParserSuite.scala index 8a3bfd47c6ea3..dcd3ad5681b06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableSetSerdeParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableSetSerdeParserSuite.scala @@ -29,7 +29,7 @@ class AlterTableSetSerdeParserSuite extends AnalysisTest with SharedSparkSession "WITH SERDEPROPERTIES('key_without_value', 'key_with_value'='x')" checkError( exception = parseException(parsePlan)(sql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> "Values must be specified for key(s): [key_without_value]"), context = ExpectedContext( fragment = sql, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableSetTblPropertiesParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableSetTblPropertiesParserSuite.scala index 78abd1a8b7fd3..6b2c7069c4211 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableSetTblPropertiesParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableSetTblPropertiesParserSuite.scala @@ -45,7 +45,7 @@ class AlterTableSetTblPropertiesParserSuite extends AnalysisTest with SharedSpar val sql = "ALTER TABLE my_tab SET TBLPROPERTIES('key_without_value', 'key_with_value'='x')" checkError( exception = parseException(sql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> "Values must be specified for key(s): [key_without_value]"), context = ExpectedContext( fragment = sql, diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableSetTblPropertiesSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableSetTblPropertiesSuiteBase.scala index ac3c84dff718c..52a90497fdd37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableSetTblPropertiesSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableSetTblPropertiesSuiteBase.scala @@ -52,7 +52,7 @@ trait AlterTableSetTblPropertiesSuiteBase extends QueryTest with DDLCommandTestU exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> toSQLId(t)), context = ExpectedContext( fragment = t, @@ -96,7 +96,7 @@ trait AlterTableSetTblPropertiesSuiteBase extends QueryTest with DDLCommandTestU exception = intercept[ParseException] { sql(sqlText) }, - errorClass = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", + condition = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", parameters = Map( "property" -> key, "msg" -> keyParameters.getOrElse( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableUnsetTblPropertiesParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableUnsetTblPropertiesParserSuite.scala index 1e675a64f2235..c9582a75aa8cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableUnsetTblPropertiesParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableUnsetTblPropertiesParserSuite.scala @@ -55,7 +55,7 @@ class AlterTableUnsetTblPropertiesParserSuite extends AnalysisTest with SharedSp val sql = "ALTER TABLE my_tab UNSET TBLPROPERTIES('key_without_value', 'key_with_value'='x')" checkError( exception = parseException(sql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> "Values should not be specified for key(s): [key_with_value]"), context = ExpectedContext( fragment = sql, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableUnsetTblPropertiesSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableUnsetTblPropertiesSuiteBase.scala index be8d85d2ef670..0013919fca08f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableUnsetTblPropertiesSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableUnsetTblPropertiesSuiteBase.scala @@ -52,7 +52,7 @@ trait AlterTableUnsetTblPropertiesSuiteBase extends QueryTest with DDLCommandTes exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> toSQLId(t)), context = ExpectedContext( fragment = t, @@ -116,7 +116,7 @@ trait AlterTableUnsetTblPropertiesSuiteBase extends QueryTest with DDLCommandTes exception = intercept[ParseException] { sql(sqlText) }, - errorClass = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", + condition = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY", parameters = Map( "property" -> key, "msg" -> keyParameters.getOrElse( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala index 9c7f370278128..a5bb3058bedd4 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala @@ -53,7 +53,7 @@ trait CharVarcharDDLTestBase extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(alterSQL) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"CHAR(4)\"", "newType" -> "\"CHAR(5)\"", @@ -74,7 +74,7 @@ trait CharVarcharDDLTestBase extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"STRING\"", "newType" -> "\"CHAR(5)\"", @@ -95,7 +95,7 @@ trait CharVarcharDDLTestBase extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"INT\"", "newType" -> "\"CHAR(5)\"", @@ -124,7 +124,7 @@ trait CharVarcharDDLTestBase extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"VARCHAR(4)\"", "newType" -> "\"VARCHAR(3)\"", @@ -301,7 +301,7 @@ class DSV2CharVarcharDDLTestSuite extends CharVarcharDDLTestBase exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"CHAR(4)\"", "newType" -> "\"VARCHAR(3)\"", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceParserSuite.scala index 46ccc0b1312da..469e1a06920a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceParserSuite.scala @@ -70,7 +70,7 @@ class CreateNamespaceParserSuite extends AnalysisTest with SharedSparkSession { val sql1 = createNamespace("COMMENT 'namespace_comment'") checkError( exception = parseException(sql1), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "COMMENT"), context = ExpectedContext( fragment = sql1, @@ -80,7 +80,7 @@ class CreateNamespaceParserSuite extends AnalysisTest with SharedSparkSession { val sql2 = createNamespace("LOCATION '/home/user/db'") checkError( exception = parseException(sql2), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "LOCATION"), context = ExpectedContext( fragment = sql2, @@ -90,7 +90,7 @@ class CreateNamespaceParserSuite extends AnalysisTest with SharedSparkSession { val sql3 = createNamespace("WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c')") checkError( exception = parseException(sql3), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "WITH PROPERTIES"), context = ExpectedContext( fragment = sql3, @@ -100,7 +100,7 @@ class CreateNamespaceParserSuite extends AnalysisTest with SharedSparkSession { val sql4 = createNamespace("WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')") checkError( exception = parseException(sql4), - errorClass = 
"DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "WITH DBPROPERTIES"), context = ExpectedContext( fragment = sql4, @@ -112,7 +112,7 @@ class CreateNamespaceParserSuite extends AnalysisTest with SharedSparkSession { val sql = "CREATE NAMESPACE a.b.c WITH PROPERTIES('key_without_value', 'key_with_value'='x')" checkError( exception = parseException(sql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> "Values must be specified for key(s): [key_without_value]"), context = ExpectedContext( fragment = sql, @@ -127,7 +127,7 @@ class CreateNamespaceParserSuite extends AnalysisTest with SharedSparkSession { |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')""".stripMargin checkError( exception = parseException(sql), - errorClass = "UNSUPPORTED_FEATURE.SET_PROPERTIES_AND_DBPROPERTIES", + condition = "UNSUPPORTED_FEATURE.SET_PROPERTIES_AND_DBPROPERTIES", parameters = Map.empty, context = ExpectedContext( fragment = sql, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceSuiteBase.scala index bfc32a761d57c..9733b104beecb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateNamespaceSuiteBase.scala @@ -74,7 +74,7 @@ trait CreateNamespaceSuiteBase extends QueryTest with DDLCommandTestUtils { exception = intercept[SparkIllegalArgumentException] { sql(sqlText) }, - errorClass = "INVALID_EMPTY_LOCATION", + condition = "INVALID_EMPTY_LOCATION", parameters = Map("location" -> "")) val uri = new Path(path).toUri sql(s"CREATE NAMESPACE $ns LOCATION '$uri'") @@ -99,7 +99,7 @@ trait CreateNamespaceSuiteBase extends QueryTest with DDLCommandTestUtils { sql(s"CREATE NAMESPACE $ns") } checkError(e, - errorClass = "SCHEMA_ALREADY_EXISTS", + condition = "SCHEMA_ALREADY_EXISTS", parameters = Map("schemaName" -> parsed)) // The following will be no-op since the namespace already exists. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 70276051defa9..176eb7c290764 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -64,7 +64,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val v2 = "INSERT OVERWRITE DIRECTORY USING parquet SELECT 1 as a" checkError( exception = parseException(v2), - errorClass = "_LEGACY_ERROR_TEMP_0049", + condition = "_LEGACY_ERROR_TEMP_0049", parameters = Map.empty, context = ExpectedContext( fragment = "INSERT OVERWRITE DIRECTORY USING parquet", @@ -99,7 +99,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { | OPTIONS ('path' '/tmp/file', a 1, b 0.1, c TRUE)""".stripMargin checkError( exception = parseException(v4), - errorClass = "_LEGACY_ERROR_TEMP_0049", + condition = "_LEGACY_ERROR_TEMP_0049", parameters = Map.empty, context = ExpectedContext( fragment = fragment4, @@ -113,7 +113,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { |(dt='2008-08-08', country='us') WITH TABLE table_name_2""".stripMargin checkError( exception = parseException(sql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE EXCHANGE PARTITION"), context = ExpectedContext( fragment = sql, @@ -125,7 +125,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql = "ALTER TABLE table_name ARCHIVE PARTITION (dt='2008-08-08', country='us')" checkError( exception = parseException(sql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE ARCHIVE PARTITION"), context = ExpectedContext( fragment = sql, @@ -137,7 +137,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql = "ALTER TABLE table_name UNARCHIVE PARTITION (dt='2008-08-08', country='us')" checkError( exception = parseException(sql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE UNARCHIVE PARTITION"), context = ExpectedContext( fragment = sql, @@ -149,7 +149,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql1 = "ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' OUTPUTFORMAT 'test'" checkError( exception = parseException(sql1), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE SET FILEFORMAT"), context = ExpectedContext( fragment = sql1, @@ -160,7 +160,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { "SET FILEFORMAT PARQUET" checkError( exception = parseException(sql2), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE SET FILEFORMAT"), context = ExpectedContext( fragment = sql2, @@ -172,7 +172,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql1 = "ALTER TABLE table_name TOUCH" checkError( exception = parseException(sql1), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE TOUCH"), context = ExpectedContext( fragment = sql1, @@ -182,7 +182,7 @@ 
class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql2 = "ALTER TABLE table_name TOUCH PARTITION (dt='2008-08-08', country='us')" checkError( exception = parseException(sql2), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE TOUCH"), context = ExpectedContext( fragment = sql2, @@ -194,7 +194,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql1 = "ALTER TABLE table_name COMPACT 'compaction_type'" checkError( exception = parseException(sql1), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE COMPACT"), context = ExpectedContext( fragment = sql1, @@ -206,7 +206,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { |COMPACT 'MAJOR'""".stripMargin checkError( exception = parseException(sql2), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE COMPACT"), context = ExpectedContext( fragment = sql2, @@ -218,7 +218,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql1 = "ALTER TABLE table_name CONCATENATE" checkError( exception = parseException(sql1), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE CONCATENATE"), context = ExpectedContext( fragment = sql1, @@ -228,7 +228,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql2 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') CONCATENATE" checkError( exception = parseException(sql2), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE CONCATENATE"), context = ExpectedContext( fragment = sql2, @@ -240,7 +240,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql1 = "ALTER TABLE table_name CLUSTERED BY (col_name) SORTED BY (col2_name) INTO 3 BUCKETS" checkError( exception = parseException(sql1), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE CLUSTERED BY"), context = ExpectedContext( fragment = sql1, @@ -250,7 +250,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql2 = "ALTER TABLE table_name CLUSTERED BY (col_name) INTO 3 BUCKETS" checkError( exception = parseException(sql2), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE CLUSTERED BY"), context = ExpectedContext( fragment = sql2, @@ -260,7 +260,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql3 = "ALTER TABLE table_name NOT CLUSTERED" checkError( exception = parseException(sql3), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE NOT CLUSTERED"), context = ExpectedContext( fragment = sql3, @@ -270,7 +270,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql4 = "ALTER TABLE table_name NOT SORTED" checkError( exception = parseException(sql4), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE NOT SORTED"), context = ExpectedContext( fragment = sql4, @@ -282,7 +282,7 @@ class 
DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql1 = "ALTER TABLE table_name NOT SKEWED" checkError( exception = parseException(sql1), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE NOT SKEWED"), context = ExpectedContext( fragment = sql1, @@ -292,7 +292,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql2 = "ALTER TABLE table_name NOT STORED AS DIRECTORIES" checkError( exception = parseException(sql2), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE NOT STORED AS DIRECTORIES"), context = ExpectedContext( fragment = sql2, @@ -302,7 +302,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql3 = "ALTER TABLE table_name SET SKEWED LOCATION (col_name1=\"location1\"" checkError( exception = parseException(sql3), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE SET SKEWED LOCATION"), context = ExpectedContext( fragment = sql3, @@ -312,7 +312,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql4 = "ALTER TABLE table_name SKEWED BY (key) ON (1,5,6) STORED AS DIRECTORIES" checkError( exception = parseException(sql4), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE SKEWED BY"), context = ExpectedContext( fragment = sql4, @@ -326,7 +326,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { |COMMENT 'test_comment', new_col2 LONG COMMENT 'test_comment2') RESTRICT""".stripMargin checkError( exception = parseException(sql), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE REPLACE COLUMNS"), context = ExpectedContext( fragment = sql, @@ -351,7 +351,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { checkError( exception = parseException(sql1), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map( "message" -> "CREATE TEMPORARY TABLE ... AS ..., use CREATE TEMPORARY VIEW instead"), context = ExpectedContext( @@ -365,7 +365,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { |AS SELECT key, value FROM src ORDER BY key, value""".stripMargin checkError( exception = parseException(sql2), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map( "message" -> "Schema may not be specified in a Create Table As Select (CTAS) statement"), context = ExpectedContext( @@ -379,7 +379,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { |AS SELECT key, value FROM src ORDER BY key, value""".stripMargin checkError( exception = parseException(sql3), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "CREATE TABLE ... 
SKEWED BY"), context = ExpectedContext( fragment = sql3, @@ -392,7 +392,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { |FROM testData""".stripMargin checkError( exception = parseException(sql4), - errorClass = "_LEGACY_ERROR_TEMP_0048", + condition = "_LEGACY_ERROR_TEMP_0048", parameters = Map.empty, context = ExpectedContext( fragment = sql4, @@ -402,13 +402,13 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { test("Invalid interval term should throw AnalysisException") { val sql1 = "select interval '42-32' year to month" - val value1 = "Error parsing interval year-month string: " + - "requirement failed: month 32 outside range [0, 11]" val fragment1 = "'42-32' year to month" checkError( exception = parseException(sql1), - errorClass = "_LEGACY_ERROR_TEMP_0063", - parameters = Map("msg" -> value1), + condition = "INVALID_INTERVAL_FORMAT.INTERVAL_PARSING", + parameters = Map( + "input" -> "42-32", + "interval" -> "year-month"), context = ExpectedContext( fragment = fragment1, start = 16, @@ -418,7 +418,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val fragment2 = "'5 49:12:15' day to second" checkError( exception = parseException(sql2), - errorClass = "_LEGACY_ERROR_TEMP_0063", + condition = "_LEGACY_ERROR_TEMP_0063", parameters = Map("msg" -> "requirement failed: hour 49 outside range [0, 23]"), context = ExpectedContext( fragment = fragment2, @@ -429,7 +429,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val fragment3 = "'23:61:15' hour to second" checkError( exception = parseException(sql3), - errorClass = "_LEGACY_ERROR_TEMP_0063", + condition = "_LEGACY_ERROR_TEMP_0063", parameters = Map("msg" -> "requirement failed: minute 61 outside range [0, 59]"), context = ExpectedContext( fragment = fragment3, @@ -524,7 +524,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val v3 = "CREATE TEMPORARY VIEW a.b AS SELECT 1" checkError( exception = parseException(v3), - errorClass = "TEMP_VIEW_NAME_TOO_MANY_NAME_PARTS", + condition = "TEMP_VIEW_NAME_TOO_MANY_NAME_PARTS", parameters = Map("actualName" -> "`a`.`b`"), context = ExpectedContext( fragment = v3, @@ -579,7 +579,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val v1 = "CREATE VIEW view1 partitioned on (ds, hr) as select * from srcpart" checkError( exception = parseException(v1), - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "CREATE VIEW ... 
PARTITIONED ON"), context = ExpectedContext( fragment = v1, @@ -599,7 +599,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql1 = createViewStatement("COMMENT 'BLABLA'") checkError( exception = parseException(sql1), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "COMMENT"), context = ExpectedContext( fragment = sql1, @@ -609,7 +609,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql2 = createViewStatement("TBLPROPERTIES('prop1Key'=\"prop1Val\")") checkError( exception = parseException(sql2), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "TBLPROPERTIES"), context = ExpectedContext( fragment = sql2, @@ -655,7 +655,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql = "CREATE FUNCTION a as 'fun' USING OTHER 'o'" checkError( exception = parseException(sql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> "CREATE FUNCTION with resource type 'other'"), context = ExpectedContext( fragment = sql, @@ -687,7 +687,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql1 = "DROP TEMPORARY FUNCTION a.b" checkError( exception = parseException(sql1), - errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", + condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "funcName" -> "`a`.`b`"), context = ExpectedContext( fragment = sql1, @@ -697,7 +697,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { val sql2 = "DROP TEMPORARY FUNCTION IF EXISTS a.b" checkError( exception = parseException(sql2), - errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", + condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "funcName" -> "`a`.`b`"), context = ExpectedContext( fragment = sql2, @@ -713,7 +713,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { |AS SELECT * FROM tab1""".stripMargin checkError( exception = parseException(sql), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> "TBLPROPERTIES can't coexist with CREATE TEMPORARY VIEW"), context = ExpectedContext( fragment = sql, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 5c1090c288ed5..8307326f17fcf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -88,7 +88,7 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"CREATE TABLE $tabName (i INT, j STRING) STORED AS parquet") }, - errorClass = "NOT_SUPPORTED_COMMAND_WITHOUT_HIVE_SUPPORT", + condition = "NOT_SUPPORTED_COMMAND_WITHOUT_HIVE_SUPPORT", parameters = Map("cmd" -> "CREATE Hive TABLE (AS SELECT)") ) } @@ -108,7 +108,7 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { |LOCATION '${tempDir.toURI}' """.stripMargin) }, - errorClass = "NOT_SUPPORTED_COMMAND_WITHOUT_HIVE_SUPPORT", + condition = "NOT_SUPPORTED_COMMAND_WITHOUT_HIVE_SUPPORT", parameters = Map("cmd" -> "CREATE Hive TABLE (AS SELECT)") ) } @@ -122,7 +122,7 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with 
SharedSparkSession { exception = intercept[AnalysisException] { sql("CREATE TABLE t STORED AS parquet SELECT 1 as a, 1 as b") }, - errorClass = "NOT_SUPPORTED_COMMAND_WITHOUT_HIVE_SUPPORT", + condition = "NOT_SUPPORTED_COMMAND_WITHOUT_HIVE_SUPPORT", parameters = Map("cmd" -> "CREATE Hive TABLE (AS SELECT)") ) @@ -131,7 +131,7 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { exception = intercept[AnalysisException] { sql("CREATE TABLE t STORED AS parquet SELECT a, b from t1") }, - errorClass = "NOT_SUPPORTED_COMMAND_WITHOUT_HIVE_SUPPORT", + condition = "NOT_SUPPORTED_COMMAND_WITHOUT_HIVE_SUPPORT", parameters = Map("cmd" -> "CREATE Hive TABLE (AS SELECT)") ) } @@ -195,7 +195,7 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { exception = intercept[AnalysisException] { sql("CREATE TABLE t LIKE s USING org.apache.spark.sql.hive.orc") }, - errorClass = "_LEGACY_ERROR_TEMP_1138", + condition = "_LEGACY_ERROR_TEMP_1138", parameters = Map.empty ) } @@ -209,7 +209,7 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { } checkError( exception = e, - errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", sqlState = "0A000", parameters = Map("tableName" -> "`spark_catalog`.`default`.`t`", "operation" -> "ALTER COLUMN ... FIRST | AFTER")) @@ -379,7 +379,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[SparkRuntimeException] { sql(createStmt) }, - errorClass = "LOCATION_ALREADY_EXISTS", + condition = "LOCATION_ALREADY_EXISTS", parameters = Map( "location" -> expectedLoc, "identifier" -> s"`$SESSION_CATALOG_NAME`.`default`.`tab1`")) @@ -392,7 +392,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[SparkRuntimeException] { sql(s"CREATE TABLE IF NOT EXISTS tab1 LIKE tab2") }, - errorClass = "LOCATION_ALREADY_EXISTS", + condition = "LOCATION_ALREADY_EXISTS", parameters = Map( "location" -> expectedLoc, "identifier" -> s"`$SESSION_CATALOG_NAME`.`default`.`tab1`")) @@ -425,7 +425,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { if (userSpecifiedSchema.isEmpty && userSpecifiedPartitionCols.nonEmpty) { checkError( exception = intercept[AnalysisException](sql(sqlCreateTable)), - errorClass = "SPECIFY_PARTITION_IS_NOT_ALLOWED", + condition = "SPECIFY_PARTITION_IS_NOT_ALLOWED", parameters = Map.empty ) } else { @@ -529,7 +529,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql(s"CREATE TABLE t($c0 INT, $c1 INT) USING parquet") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c1.toLowerCase(Locale.ROOT)}`")) } } @@ -540,7 +540,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("CREATE TABLE tbl(a int, b string) USING json PARTITIONED BY (c)") }, - errorClass = "COLUMN_NOT_DEFINED_IN_TABLE", + condition = "COLUMN_NOT_DEFINED_IN_TABLE", parameters = Map( "colType" -> "partition", "colName" -> "`c`", @@ -553,7 +553,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("CREATE TABLE tbl(a int, b string) USING json CLUSTERED BY (c) INTO 4 BUCKETS") }, - errorClass = "COLUMN_NOT_DEFINED_IN_TABLE", + condition = "COLUMN_NOT_DEFINED_IN_TABLE", parameters = Map( "colType" -> "bucket", "colName" -> "`c`", @@ -568,7 +568,7 @@ abstract class DDLSuite 
extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql(s"CREATE TABLE t($c0 INT) USING parquet PARTITIONED BY ($c0, $c1)") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c1.toLowerCase(Locale.ROOT)}`")) } } @@ -581,7 +581,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql(s"CREATE TABLE t($c0 INT) USING parquet CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c1.toLowerCase(Locale.ROOT)}`")) checkError( @@ -591,7 +591,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { | SORTED BY ($c0, $c1) INTO 2 BUCKETS """.stripMargin) }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c1.toLowerCase(Locale.ROOT)}`")) } } @@ -618,7 +618,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { .option("path", dir2.getCanonicalPath) .saveAsTable("path_test") }, - errorClass = "_LEGACY_ERROR_TEMP_1160", + condition = "_LEGACY_ERROR_TEMP_1160", parameters = Map( "identifier" -> s"`$SESSION_CATALOG_NAME`.`default`.`path_test`", "existingTableLoc" -> ".*", @@ -687,7 +687,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql(s"CREATE VIEW t AS SELECT * FROM VALUES (1, 1) AS t($c0, $c1)") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c1.toLowerCase(Locale.ROOT)}`")) } } @@ -798,7 +798,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { |USING org.apache.spark.sql.execution.datasources.csv.CSVFileFormat |OPTIONS (PATH '${tmpFile.toURI}') """.stripMargin)}, - errorClass = "TEMP_TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TEMP_TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> "`testview`")) } } @@ -821,7 +821,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("ALTER TABLE tab1 RENAME TO default.tab2") }, - errorClass = "_LEGACY_ERROR_TEMP_1074", + condition = "_LEGACY_ERROR_TEMP_1074", parameters = Map( "oldName" -> "`tab1`", "newName" -> "`default`.`tab2`", @@ -850,7 +850,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("ALTER TABLE view1 RENAME TO default.tab2") }, - errorClass = "_LEGACY_ERROR_TEMP_1074", + condition = "_LEGACY_ERROR_TEMP_1074", parameters = Map( "oldName" -> "`view1`", "newName" -> "`default`.`tab2`", @@ -872,7 +872,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { checkAnswer(spark.table("tab1"), spark.range(10).toDF()) checkError( exception = intercept[AnalysisException] { spark.table("tab2") }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`tab2`") ) } @@ -959,7 +959,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[ParseException] { sql(sql1) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE CLUSTERED BY"), context = ExpectedContext(fragment = sql1, start = 0, stop = 70)) val sql2 = "ALTER TABLE dbx.tab1 CLUSTERED BY (fuji) SORTED BY (grape) INTO 5 BUCKETS" @@ -967,7 +967,7 @@ abstract class DDLSuite extends QueryTest 
with DDLSuiteBase { exception = intercept[ParseException] { sql(sql2) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE CLUSTERED BY"), context = ExpectedContext(fragment = sql2, start = 0, stop = 72)) val sql3 = "ALTER TABLE dbx.tab1 NOT CLUSTERED" @@ -975,7 +975,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[ParseException] { sql(sql3) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE NOT CLUSTERED"), context = ExpectedContext(fragment = sql3, start = 0, stop = 33)) val sql4 = "ALTER TABLE dbx.tab1 NOT SORTED" @@ -983,7 +983,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[ParseException] { sql(sql4) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE NOT SORTED"), context = ExpectedContext(fragment = sql4, start = 0, stop = 30)) } @@ -999,7 +999,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[ParseException] { sql(sql1) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE SKEWED BY"), context = ExpectedContext(fragment = sql1, start = 0, stop = 113) ) @@ -1009,7 +1009,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[ParseException] { sql(sql2) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE SKEWED BY"), context = ExpectedContext(fragment = sql2, start = 0, stop = 113) ) @@ -1018,7 +1018,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[ParseException] { sql(sql3) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE NOT SKEWED"), context = ExpectedContext(fragment = sql3, start = 0, stop = 30) ) @@ -1027,7 +1027,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[ParseException] { sql(sql4) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER TABLE NOT STORED AS DIRECTORIES"), context = ExpectedContext(fragment = sql4, start = 0, stop = 45) ) @@ -1039,7 +1039,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[ParseException] { sql(sql1) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER VIEW ... ADD PARTITION"), context = ExpectedContext(fragment = sql1, start = 0, stop = 54) ) @@ -1051,7 +1051,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[ParseException] { sql(sql1) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "ALTER VIEW ... 
DROP PARTITION"), context = ExpectedContext(fragment = sql1, start = 0, stop = 51) ) @@ -1085,7 +1085,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { } checkError( exception = e, - errorClass = "WRONG_COMMAND_FOR_OBJECT_TYPE", + condition = "WRONG_COMMAND_FOR_OBJECT_TYPE", parameters = Map( "alternative" -> "DROP TABLE", "operation" -> "DROP VIEW", @@ -1125,21 +1125,21 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("DROP TEMPORARY FUNCTION year") }, - errorClass = "_LEGACY_ERROR_TEMP_1255", + condition = "_LEGACY_ERROR_TEMP_1255", parameters = Map("functionName" -> "year") ) checkError( exception = intercept[AnalysisException] { sql("DROP TEMPORARY FUNCTION YeAr") }, - errorClass = "_LEGACY_ERROR_TEMP_1255", + condition = "_LEGACY_ERROR_TEMP_1255", parameters = Map("functionName" -> "YeAr") ) checkError( exception = intercept[AnalysisException] { sql("DROP TEMPORARY FUNCTION `YeAr`") }, - errorClass = "_LEGACY_ERROR_TEMP_1255", + condition = "_LEGACY_ERROR_TEMP_1255", parameters = Map("functionName" -> "YeAr") ) } @@ -1216,7 +1216,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { checkError( exception = intercept[AnalysisException] { sql("CREATE TABLE tab1 USING json") }, - errorClass = "UNABLE_TO_INFER_SCHEMA", + condition = "UNABLE_TO_INFER_SCHEMA", parameters = Map("format" -> "JSON") ) @@ -1244,7 +1244,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { |CLUSTERED BY (nonexistentColumnA) SORTED BY (nonexistentColumnB) INTO 2 BUCKETS """.stripMargin) }, - errorClass = "SPECIFY_BUCKETING_IS_NOT_ALLOWED", + condition = "SPECIFY_BUCKETING_IS_NOT_ALLOWED", parameters = Map.empty ) } @@ -1271,7 +1271,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("CREATE TEMPORARY VIEW view1 (col1, col3) AS SELECT * FROM tab1") }, - errorClass = "CREATE_VIEW_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "CREATE_VIEW_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "viewName" -> "`view1`", "viewColumns" -> "`col1`, `col3`", @@ -1298,7 +1298,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { sql("CREATE TEMPORARY TABLE t_temp (c3 int, c4 string) USING JSON") } checkError(e, - errorClass = "TEMP_TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TEMP_TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> "`t_temp`")) } } @@ -1310,7 +1310,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { sql("CREATE TEMPORARY VIEW t_temp (c3 int, c4 string) USING JSON") } checkError(e, - errorClass = "TEMP_TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TEMP_TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> "`t_temp`")) } } @@ -1325,7 +1325,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { df.write.mode("append").partitionBy("a").saveAsTable("partitionedTable") }, - errorClass = "_LEGACY_ERROR_TEMP_1163", + condition = "_LEGACY_ERROR_TEMP_1163", parameters = Map( "tableName" -> "spark_catalog.default.partitionedtable", "specifiedPartCols" -> "a", @@ -1336,7 +1336,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { df.write.mode("append").partitionBy("b", "a").saveAsTable("partitionedTable") }, - errorClass = "_LEGACY_ERROR_TEMP_1163", + condition = "_LEGACY_ERROR_TEMP_1163", parameters = Map( "tableName" -> "spark_catalog.default.partitionedtable", 
"specifiedPartCols" -> "b, a", @@ -1347,7 +1347,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { df.write.mode("append").saveAsTable("partitionedTable") }, - errorClass = "_LEGACY_ERROR_TEMP_1163", + condition = "_LEGACY_ERROR_TEMP_1163", parameters = Map( "tableName" -> "spark_catalog.default.partitionedtable", "specifiedPartCols" -> "", "existingPartCols" -> "a, b") @@ -1934,7 +1934,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") }, - errorClass = "_LEGACY_ERROR_TEMP_1260", + condition = "_LEGACY_ERROR_TEMP_1260", parameters = Map( "tableType" -> ("org\\.apache\\.spark\\.sql\\.execution\\." + "datasources\\.v2\\.text\\.TextDataSourceV2.*"), @@ -1950,7 +1950,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("ALTER TABLE tmp_v ADD COLUMNS (c3 INT)") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> "`tmp_v`", "operation" -> "ALTER TABLE ... ADD COLUMNS"), @@ -1969,7 +1969,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("ALTER TABLE v1 ADD COLUMNS (c3 INT)") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`v1`", "operation" -> "ALTER TABLE ... ADD COLUMNS"), @@ -1988,7 +1988,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("ALTER TABLE t1 ADD COLUMNS (c1 string)") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`c1`")) } } @@ -2003,7 +2003,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`c1`")) } else { sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") @@ -2058,7 +2058,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql(s"SET ${config.CPUS_PER_TASK.key} = 4") }, - errorClass = "CANNOT_MODIFY_CONFIG", + condition = "CANNOT_MODIFY_CONFIG", parameters = Map( "key" -> "\"spark.task.cpus\"", "docroot" -> "https://spark.apache.org/docs/latest")) @@ -2120,7 +2120,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { } checkError( exception = e1, - errorClass = "DATA_SOURCE_NOT_FOUND", + condition = "DATA_SOURCE_NOT_FOUND", parameters = Map("provider" -> "unknown") ) @@ -2151,7 +2151,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[SparkException] { sql(s"ADD FILE $testDir") }, - errorClass = "UNSUPPORTED_ADD_FILE.DIRECTORY", + condition = "UNSUPPORTED_ADD_FILE.DIRECTORY", parameters = Map("path" -> s"file:${testDir.getCanonicalPath}/") ) } @@ -2163,7 +2163,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("REFRESH FUNCTION md5") }, - errorClass = "_LEGACY_ERROR_TEMP_1017", + condition = "_LEGACY_ERROR_TEMP_1017", parameters = Map( "name" -> "md5", "cmd" -> "REFRESH FUNCTION", "hintStr" -> ""), @@ -2172,7 +2172,7 @@ abstract class DDLSuite extends QueryTest with 
DDLSuiteBase { exception = intercept[AnalysisException] { sql("REFRESH FUNCTION default.md5") }, - errorClass = "UNRESOLVED_ROUTINE", + condition = "UNRESOLVED_ROUTINE", parameters = Map( "routineName" -> "`default`.`md5`", "searchPath" -> "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]"), @@ -2187,7 +2187,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("REFRESH FUNCTION func1") }, - errorClass = "_LEGACY_ERROR_TEMP_1017", + condition = "_LEGACY_ERROR_TEMP_1017", parameters = Map("name" -> "func1", "cmd" -> "REFRESH FUNCTION", "hintStr" -> ""), context = ExpectedContext( fragment = "func1", @@ -2203,7 +2203,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("REFRESH FUNCTION func1") }, - errorClass = "UNRESOLVED_ROUTINE", + condition = "UNRESOLVED_ROUTINE", parameters = Map( "routineName" -> "`func1`", "searchPath" -> "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]"), @@ -2219,7 +2219,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("REFRESH FUNCTION func2") }, - errorClass = "UNRESOLVED_ROUTINE", + condition = "UNRESOLVED_ROUTINE", parameters = Map( "routineName" -> "`func2`", "searchPath" -> "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]"), @@ -2235,7 +2235,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("REFRESH FUNCTION func1") }, - errorClass = "ROUTINE_NOT_FOUND", + condition = "ROUTINE_NOT_FOUND", parameters = Map("routineName" -> "`default`.`func1`") ) @@ -2248,7 +2248,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("REFRESH FUNCTION func1") }, - errorClass = "CANNOT_LOAD_FUNCTION_CLASS", + condition = "CANNOT_LOAD_FUNCTION_CLASS", parameters = Map( "className" -> "test.non.exists.udf", "functionName" -> "`spark_catalog`.`default`.`func1`" @@ -2267,7 +2267,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("REFRESH FUNCTION rand") }, - errorClass = "_LEGACY_ERROR_TEMP_1017", + condition = "_LEGACY_ERROR_TEMP_1017", parameters = Map("name" -> "rand", "cmd" -> "REFRESH FUNCTION", "hintStr" -> ""), context = ExpectedContext(fragment = "rand", start = 17, stop = 20) ) @@ -2282,12 +2282,23 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql(s"create table t(a int, b int generated always as (a + 1)) using parquet") }, - errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", parameters = Map("tableName" -> "`spark_catalog`.`default`.`t`", "operation" -> "generated columns") ) } + test("SPARK-48824: No identity columns with V1") { + checkError( + exception = intercept[AnalysisException] { + sql(s"create table t(a int, b bigint generated always as identity()) using parquet") + }, + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + parameters = Map("tableName" -> "`spark_catalog`.`default`.`t`", + "operation" -> "identity columns") + ) + } + test("SPARK-44837: Error when altering partition column in non-delta table") { withTable("t") { sql("CREATE TABLE t(i INT, j INT, k INT) USING parquet PARTITIONED BY (i, j)") @@ -2295,7 +2306,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] 
{ sql("ALTER TABLE t ALTER COLUMN i COMMENT 'comment'") }, - errorClass = "CANNOT_ALTER_PARTITION_COLUMN", + condition = "CANNOT_ALTER_PARTITION_COLUMN", sqlState = "428FR", parameters = Map("tableName" -> "`spark_catalog`.`default`.`t`", "columnName" -> "`i`") @@ -2318,7 +2329,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql(alterInt) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"STRING COLLATE UTF8_LCASE\"", "originName" -> "`col`", @@ -2354,7 +2365,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql(alterMap) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"MAP\"", "originName" -> "`col`", @@ -2381,7 +2392,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("ALTER TABLE t1 ALTER COLUMN col TYPE STRING COLLATE UTF8_LCASE") }, - errorClass = "CANNOT_ALTER_PARTITION_COLUMN", + condition = "CANNOT_ALTER_PARTITION_COLUMN", sqlState = "428FR", parameters = Map("tableName" -> "`spark_catalog`.`default`.`t1`", "columnName" -> "`col`") ) @@ -2390,7 +2401,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { exception = intercept[AnalysisException] { sql("ALTER TABLE t2 ALTER COLUMN col TYPE STRING COLLATE UTF8_LCASE") }, - errorClass = "CANNOT_ALTER_COLLATION_BUCKET_COLUMN", + condition = "CANNOT_ALTER_COLLATION_BUCKET_COLUMN", sqlState = "428FR", parameters = Map("tableName" -> "`spark_catalog`.`default`.`t2`", "columnName" -> "`col`") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DeclareVariableParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DeclareVariableParserSuite.scala index bc42937b93a92..02f1d012297bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DeclareVariableParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DeclareVariableParserSuite.scala @@ -179,7 +179,7 @@ class DeclareVariableParserSuite extends AnalysisTest with SharedSparkSession { exception = intercept[ParseException] { parsePlan("DECLARE VARIABLE IF NOT EXISTS var1 INT") }, - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'EXISTS'", "hint" -> "") ) @@ -189,7 +189,7 @@ class DeclareVariableParserSuite extends AnalysisTest with SharedSparkSession { exception = intercept[ParseException] { parsePlan(sqlText) }, - errorClass = "INVALID_SQL_SYNTAX.VARIABLE_TYPE_OR_DEFAULT_REQUIRED", + condition = "INVALID_SQL_SYNTAX.VARIABLE_TYPE_OR_DEFAULT_REQUIRED", parameters = Map.empty, context = ExpectedContext(fragment = sqlText, start = 0, stop = 20) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DescribeNamespaceSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DescribeNamespaceSuiteBase.scala index 6945352564e1e..36b17568d4716 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DescribeNamespaceSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DescribeNamespaceSuiteBase.scala @@ -43,7 +43,7 @@ trait DescribeNamespaceSuiteBase extends QueryTest with DDLCommandTestUtils { sql(s"DESCRIBE NAMESPACE EXTENDED $catalog.$ns") } checkError(e, - errorClass = "SCHEMA_NOT_FOUND", + 
condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> s"`$catalog`.`$ns`")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DescribeTableParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DescribeTableParserSuite.scala index ee1b588741cd4..944f20bf8e924 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DescribeTableParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DescribeTableParserSuite.scala @@ -78,7 +78,7 @@ class DescribeTableParserSuite extends AnalysisTest { val sql = "DESCRIBE TABLE t PARTITION (ds='1970-01-01') col" checkError( exception = parseException(parsePlan)(sql), - errorClass = "UNSUPPORTED_FEATURE.DESC_TABLE_COLUMN_PARTITION", + condition = "UNSUPPORTED_FEATURE.DESC_TABLE_COLUMN_PARTITION", parameters = Map.empty, context = ExpectedContext( fragment = sql, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DropNamespaceSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DropNamespaceSuiteBase.scala index 2243517550b2c..2a7fe53a848e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DropNamespaceSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DropNamespaceSuiteBase.scala @@ -64,7 +64,7 @@ trait DropNamespaceSuiteBase extends QueryTest with DDLCommandTestUtils { sql(s"DROP NAMESPACE $catalog.unknown") } checkError(e, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> s"`$catalog`.`unknown`")) } @@ -78,7 +78,7 @@ trait DropNamespaceSuiteBase extends QueryTest with DDLCommandTestUtils { sql(s"DROP NAMESPACE $catalog.ns") } checkError(e, - errorClass = "SCHEMA_NOT_EMPTY", + condition = "SCHEMA_NOT_EMPTY", parameters = Map("schemaName" -> "`ns`")) sql(s"DROP TABLE $catalog.ns.table") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DropVariableParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DropVariableParserSuite.scala index f2af7e5dedca0..bc5e8c60ec812 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DropVariableParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DropVariableParserSuite.scala @@ -42,14 +42,14 @@ class DropVariableParserSuite extends AnalysisTest with SharedSparkSession { exception = intercept[ParseException] { parsePlan("DROP VARIABLE var1") }, - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'VARIABLE'", "hint" -> "") ) checkError( exception = intercept[ParseException] { parsePlan("DROP VAR var1") }, - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'VAR'", "hint" -> "") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 73bcde1e6e5be..5a4d7c86761fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -264,7 +264,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[ParseException] { parsePlan(sql) }, - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = parameters, context = 
context ) @@ -306,7 +306,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[SparkUnsupportedOperationException] { parseAndResolve(query) }, - errorClass = "_LEGACY_ERROR_TEMP_2067", + condition = "_LEGACY_ERROR_TEMP_2067", parameters = Map("transform" -> transform)) } } @@ -323,7 +323,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[SparkUnsupportedOperationException] { parseAndResolve(query) }, - errorClass = "UNSUPPORTED_FEATURE.MULTIPLE_BUCKET_TRANSFORMS", + condition = "UNSUPPORTED_FEATURE.MULTIPLE_BUCKET_TRANSFORMS", parameters = Map.empty) } } @@ -417,7 +417,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[ParseException] { parseAndResolve(v2) }, - errorClass = "_LEGACY_ERROR_TEMP_0032", + condition = "_LEGACY_ERROR_TEMP_0032", parameters = Map("pathOne" -> "/tmp/file", "pathTwo" -> "/tmp/file"), context = ExpectedContext( fragment = v2, @@ -763,7 +763,7 @@ class PlanResolutionSuite extends AnalysisTest { } checkError( e, - errorClass = "UNSUPPORTED_FEATURE.CATALOG_OPERATION", + condition = "UNSUPPORTED_FEATURE.CATALOG_OPERATION", parameters = Map("catalogName" -> "`testcat`", "operation" -> "views")) } @@ -1207,7 +1207,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[AnalysisException] { parseAndResolve(sql6, checkAnalysis = true) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`DEFAULT`", "proposal" -> "`i`, `s`"), context = ExpectedContext( fragment = "DEFAULT", @@ -1219,7 +1219,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[AnalysisException] { parseAndResolve(sql7, checkAnalysis = true) }, - errorClass = "NO_DEFAULT_COLUMN_VALUE_AVAILABLE", + condition = "NO_DEFAULT_COLUMN_VALUE_AVAILABLE", parameters = Map("colName" -> "`x`") ) } @@ -1267,7 +1267,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[AnalysisException] { parseAndResolve(sql2, checkAnalysis = true) }, - errorClass = "NO_DEFAULT_COLUMN_VALUE_AVAILABLE", + condition = "NO_DEFAULT_COLUMN_VALUE_AVAILABLE", parameters = Map("colName" -> "`x`") ) @@ -1276,7 +1276,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[AnalysisException] { parseAndResolve(sql3, checkAnalysis = true) }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> "`testcat`.`tab2`", "tableColumns" -> "`i`, `x`", @@ -1337,7 +1337,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[AnalysisException] { parseAndResolve(sql3) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`j`", @@ -1350,7 +1350,7 @@ class PlanResolutionSuite extends AnalysisTest { } checkError( exception = e2, - errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", sqlState = "0A000", parameters = Map("tableName" -> "`spark_catalog`.`default`.`v1Table`", "operation" -> "ALTER COLUMN with qualified column")) @@ -1359,7 +1359,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[AnalysisException] { parseAndResolve(s"ALTER TABLE $tblName ALTER COLUMN i SET NOT NULL") }, - errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", sqlState = "0A000", 
parameters = Map("tableName" -> "`spark_catalog`.`default`.`v1Table`", "operation" -> "ALTER COLUMN ... SET NOT NULL")) @@ -1407,7 +1407,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[ParseException] { parseAndResolve(sql) }, - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> "ALTER TABLE table ALTER COLUMN requires a TYPE, a SET/DROP, a COMMENT, or a FIRST/AFTER"), context = ExpectedContext(fragment = sql, start = 0, stop = 33)) @@ -1423,7 +1423,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[AnalysisException] { parseAndResolve(sql) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`I`", @@ -1944,7 +1944,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[AnalysisException] { parseAndResolve(mergeWithDefaultReferenceInMergeCondition, checkAnalysis = true) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`DEFAULT`", "proposal" -> "`target`.`i`, `source`.`i`, `target`.`s`, `source`.`s`"), context = ExpectedContext( @@ -1973,7 +1973,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[AnalysisException] { parseAndResolve(mergeWithDefaultReferenceAsPartOfComplexExpression) }, - errorClass = "DEFAULT_PLACEMENT_INVALID", + condition = "DEFAULT_PLACEMENT_INVALID", parameters = Map.empty) val mergeWithDefaultReferenceForNonNullableCol = @@ -1988,7 +1988,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[AnalysisException] { parseAndResolve(mergeWithDefaultReferenceForNonNullableCol) }, - errorClass = "NO_DEFAULT_COLUMN_VALUE_AVAILABLE", + condition = "NO_DEFAULT_COLUMN_VALUE_AVAILABLE", parameters = Map("colName" -> "`x`") ) @@ -2093,7 +2093,7 @@ class PlanResolutionSuite extends AnalysisTest { // resolve column `i` as it's ambiguous. checkError( exception = intercept[AnalysisException](parseAndResolve(sql2)), - errorClass = "AMBIGUOUS_REFERENCE", + condition = "AMBIGUOUS_REFERENCE", parameters = Map("name" -> "`i`", "referenceNames" -> referenceNames(target, "i")), context = ExpectedContext( fragment = "i", @@ -2109,7 +2109,7 @@ class PlanResolutionSuite extends AnalysisTest { // resolve column `s` as it's ambiguous. checkError( exception = intercept[AnalysisException](parseAndResolve(sql3)), - errorClass = "AMBIGUOUS_REFERENCE", + condition = "AMBIGUOUS_REFERENCE", parameters = Map("name" -> "`s`", "referenceNames" -> referenceNames(target, "s")), context = ExpectedContext( fragment = "s", @@ -2125,7 +2125,7 @@ class PlanResolutionSuite extends AnalysisTest { // resolve column `s` as it's ambiguous. checkError( exception = intercept[AnalysisException](parseAndResolve(sql4)), - errorClass = "AMBIGUOUS_REFERENCE", + condition = "AMBIGUOUS_REFERENCE", parameters = Map("name" -> "`s`", "referenceNames" -> referenceNames(target, "s")), context = ExpectedContext( fragment = "s", @@ -2141,7 +2141,7 @@ class PlanResolutionSuite extends AnalysisTest { // resolve column `s` as it's ambiguous. 
checkError( exception = intercept[AnalysisException](parseAndResolve(sql5)), - errorClass = "AMBIGUOUS_REFERENCE", + condition = "AMBIGUOUS_REFERENCE", parameters = Map("name" -> "`s`", "referenceNames" -> referenceNames(target, "s")), context = ExpectedContext( fragment = "s", @@ -2201,7 +2201,7 @@ class PlanResolutionSuite extends AnalysisTest { // update value in not matched by source clause can only reference the target table. checkError( exception = intercept[AnalysisException](parseAndResolve(sql7)), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> s"${toSQLId(source)}.`s`", "proposal" -> "`i`, `s`"), context = ExpectedContext( fragment = s"$source.s", @@ -2235,7 +2235,7 @@ class PlanResolutionSuite extends AnalysisTest { |WHEN MATCHED THEN UPDATE SET *""".stripMargin checkError( exception = intercept[AnalysisException](parseAndResolve(sql2)), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`s`", "proposal" -> "`i`, `x`"), context = ExpectedContext(fragment = sql2, start = 0, stop = 80)) @@ -2247,7 +2247,7 @@ class PlanResolutionSuite extends AnalysisTest { |WHEN NOT MATCHED THEN INSERT *""".stripMargin checkError( exception = intercept[AnalysisException](parseAndResolve(sql3)), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`s`", "proposal" -> "`i`, `x`"), context = ExpectedContext(fragment = sql3, start = 0, stop = 80)) @@ -2442,7 +2442,7 @@ class PlanResolutionSuite extends AnalysisTest { val sql = "CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING)" checkError( exception = parseException(parsePlan)(sql), - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "':'", "hint" -> "")) } @@ -2591,49 +2591,49 @@ class PlanResolutionSuite extends AnalysisTest { val sql1 = createTableHeader("TBLPROPERTIES('test' = 'test2')") checkError( exception = parseException(parsePlan)(sql1), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "TBLPROPERTIES"), context = ExpectedContext(fragment = sql1, start = 0, stop = 117)) val sql2 = createTableHeader("LOCATION '/tmp/file'") checkError( exception = parseException(parsePlan)(sql2), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "LOCATION"), context = ExpectedContext(fragment = sql2, start = 0, stop = 95)) val sql3 = createTableHeader("COMMENT 'a table'") checkError( exception = parseException(parsePlan)(sql3), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "COMMENT"), context = ExpectedContext(fragment = sql3, start = 0, stop = 89)) val sql4 = createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS") checkError( exception = parseException(parsePlan)(sql4), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "CLUSTERED BY"), context = ExpectedContext(fragment = sql4, start = 0, stop = 119)) val sql5 = createTableHeader("PARTITIONED BY (k int)") checkError( exception = parseException(parsePlan)(sql5), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "PARTITIONED BY"), context = ExpectedContext(fragment = sql5, start = 0, stop = 99)) val sql6 = createTableHeader("STORED AS 
parquet") checkError( exception = parseException(parsePlan)(sql6), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "STORED AS/BY"), context = ExpectedContext(fragment = sql6, start = 0, stop = 89)) val sql7 = createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'") checkError( exception = parseException(parsePlan)(sql7), - errorClass = "DUPLICATE_CLAUSES", + condition = "DUPLICATE_CLAUSES", parameters = Map("clauseName" -> "ROW FORMAT"), context = ExpectedContext(fragment = sql7, start = 0, stop = 163)) } @@ -2774,7 +2774,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[ParseException] { extractTableDesc(s4) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "STORED BY"), context = ExpectedContext( fragment = "STORED BY 'storage.handler.class.name'", @@ -2867,7 +2867,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[ParseException] { parsePlan(query) }, - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map( "message" -> "CREATE TEMPORARY TABLE ..., use CREATE TEMPORARY VIEW instead"), context = ExpectedContext(fragment = query, start = 0, stop = 48)) @@ -2939,7 +2939,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[ParseException] { parsePlan(query1) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "CREATE TABLE ... SKEWED BY"), context = ExpectedContext(fragment = query1, start = 0, stop = 72)) @@ -2948,7 +2948,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[ParseException] { parsePlan(query2) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "CREATE TABLE ... SKEWED BY"), context = ExpectedContext(fragment = query2, start = 0, stop = 96)) @@ -2957,7 +2957,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[ParseException] { parsePlan(query3) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "CREATE TABLE ... 
SKEWED BY"), context = ExpectedContext(fragment = query3, start = 0, stop = 118)) } @@ -3012,7 +3012,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[ParseException] { parsePlan(query1) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "STORED BY"), context = ExpectedContext( fragment = "STORED BY 'org.papachi.StorageHandler'", @@ -3024,7 +3024,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[ParseException] { parsePlan(query2) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> "STORED BY"), context = ExpectedContext( fragment = "STORED BY 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsParserSuite.scala index 17a6df87aa0e4..c93beaa10ec13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsParserSuite.scala @@ -47,7 +47,7 @@ class ShowColumnsParserSuite extends AnalysisTest { test("illegal characters in unquoted identifier") { checkError( exception = parseException(parsePlan)("SHOW COLUMNS IN t FROM test-db"), - errorClass = "INVALID_IDENTIFIER", + condition = "INVALID_IDENTIFIER", sqlState = "42602", parameters = Map("ident" -> "test-db") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsSuiteBase.scala index c6f4e0bbd01a1..54bc10d0024f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsSuiteBase.scala @@ -57,7 +57,7 @@ trait ShowColumnsSuiteBase extends QueryTest with DDLCommandTestUtils { exception = intercept[AnalysisException] { sql(s"SHOW COLUMNS IN tbl IN ns1") }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`ns1`.`tbl`"), context = ExpectedContext(fragment = "tbl", start = 16, stop = 18) ) @@ -75,7 +75,7 @@ trait ShowColumnsSuiteBase extends QueryTest with DDLCommandTestUtils { exception = intercept[AnalysisException] { sql(sqlText1) }, - errorClass = "SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE", + condition = "SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE", parameters = Map( "namespaceA" -> s"`ns1`", "namespaceB" -> s"`ns`" @@ -88,7 +88,7 @@ trait ShowColumnsSuiteBase extends QueryTest with DDLCommandTestUtils { exception = intercept[AnalysisException] { sql(sqlText2) }, - errorClass = "SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE", + condition = "SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE", parameters = Map( "namespaceA" -> s"`${"ns".toUpperCase(Locale.ROOT)}`", "namespaceB" -> "`ns`" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowPartitionsParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowPartitionsParserSuite.scala index 3a5d57c5c7821..455689026d7d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowPartitionsParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowPartitionsParserSuite.scala @@ -47,7 +47,7 @@ class 
ShowPartitionsParserSuite extends AnalysisTest { test("empty values in non-optional partition specs") { checkError( exception = parseException(parsePlan)("SHOW PARTITIONS dbx.tab1 PARTITION (a='1', b)"), - errorClass = "INVALID_SQL_SYNTAX.EMPTY_PARTITION_VALUE", + condition = "INVALID_SQL_SYNTAX.EMPTY_PARTITION_VALUE", sqlState = "42000", parameters = Map("partKey" -> "`b`"), context = ExpectedContext( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowTablesSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowTablesSuiteBase.scala index 1890726a376ba..f6a5f6a7da26a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowTablesSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowTablesSuiteBase.scala @@ -168,7 +168,7 @@ trait ShowTablesSuiteBase extends QueryTest with DDLCommandTestUtils { exception = intercept[AnalysisException] { sql(s"SHOW TABLES IN $catalog.nonexist") }, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> s"`$catalog`.`nonexist`")) } @@ -177,7 +177,7 @@ trait ShowTablesSuiteBase extends QueryTest with DDLCommandTestUtils { exception = intercept[AnalysisException] { sql(s"SHOW TABLE EXTENDED IN $catalog.nonexist LIKE '*tbl*'") }, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> s"`$catalog`.`nonexist`")) } @@ -202,7 +202,7 @@ trait ShowTablesSuiteBase extends QueryTest with DDLCommandTestUtils { exception = intercept[AnalysisException] { sql(s"SHOW TABLE EXTENDED IN $catalog.$namespace LIKE '$table' PARTITION(id = 2)") }, - errorClass = "PARTITIONS_NOT_FOUND", + condition = "PARTITIONS_NOT_FOUND", parameters = Map( "partitionList" -> "PARTITION (`id` = 2)", "tableName" -> "`ns1`.`tbl`" @@ -220,7 +220,7 @@ trait ShowTablesSuiteBase extends QueryTest with DDLCommandTestUtils { sql(s"SHOW TABLE EXTENDED IN $catalog.$namespace LIKE '$table' PARTITION(id = 1)") } val (errorClass, parameters) = extendedPartInNonPartedTableError(catalog, namespace, table) - checkError(exception = e, errorClass = errorClass, parameters = parameters) + checkError(exception = e, condition = errorClass, parameters = parameters) } } @@ -261,7 +261,7 @@ trait ShowTablesSuiteBase extends QueryTest with DDLCommandTestUtils { sql(s"SHOW TABLE EXTENDED IN $catalog.$namespace " + s"LIKE '$table' PARTITION(id1 = 1)") }, - errorClass = "_LEGACY_ERROR_TEMP_1232", + condition = "_LEGACY_ERROR_TEMP_1232", parameters = Map( "specKeys" -> "id1", "partitionColumnNames" -> "id1, id2", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableParserSuite.scala index b903681e341f9..be37495acad05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableParserSuite.scala @@ -47,7 +47,7 @@ class TruncateTableParserSuite extends AnalysisTest { test("empty values in non-optional partition specs") { checkError( exception = parseException(parsePlan)("TRUNCATE TABLE dbx.tab1 PARTITION (a='1', b)"), - errorClass = "INVALID_SQL_SYNTAX.EMPTY_PARTITION_VALUE", + condition = "INVALID_SQL_SYNTAX.EMPTY_PARTITION_VALUE", sqlState = "42000", parameters = Map("partKey" -> "`b`"), context = ExpectedContext( diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableSuiteBase.scala index 982c568d09a79..8c985ea1f0527 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/TruncateTableSuiteBase.scala @@ -181,7 +181,7 @@ trait TruncateTableSuiteBase extends QueryTest with DDLCommandTestUtils { exception = intercept[AnalysisException] { sql("TRUNCATE TABLE v0") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> "`spark_catalog`.`default`.`v0`", "operation" -> "TRUNCATE TABLE"), @@ -198,7 +198,7 @@ trait TruncateTableSuiteBase extends QueryTest with DDLCommandTestUtils { exception = intercept[AnalysisException] { sql("TRUNCATE TABLE v1") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> "`v1`", "operation" -> "TRUNCATE TABLE"), @@ -213,7 +213,7 @@ trait TruncateTableSuiteBase extends QueryTest with DDLCommandTestUtils { exception = intercept[AnalysisException] { sql(s"TRUNCATE TABLE $v2") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> "`global_temp`.`v2`", "operation" -> "TRUNCATE TABLE"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala index dac99c8ff7023..fea0d07278c1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala @@ -43,7 +43,7 @@ trait AlterTableAddPartitionSuiteBase extends command.AlterTableAddPartitionSuit exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t ADD PARTITION (p1 = '')") }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> "The spec ([p1=]) contains an empty partition column value" ) @@ -155,7 +155,7 @@ trait AlterTableAddPartitionSuiteBase extends command.AlterTableAddPartitionSuit " PARTITION (id=2) LOCATION 'loc1'") } checkError(e, - errorClass = "PARTITIONS_ALREADY_EXIST", + condition = "PARTITIONS_ALREADY_EXIST", parameters = Map("partitionList" -> "PARTITION (`id` = 2)", "tableName" -> "`ns`.`tbl`")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableDropColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableDropColumnSuite.scala index 6370a834746a5..85c7e66bdbe57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableDropColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableDropColumnSuite.scala @@ -40,7 +40,7 @@ trait AlterTableDropColumnSuiteBase extends command.AlterTableDropColumnSuiteBas exception = intercept[AnalysisException]( sql(s"ALTER TABLE $t DROP COLUMN id") ), - errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", parameters = Map( "tableName" -> toSQLId(t), "operation" -> "DROP COLUMN" diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableDropPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableDropPartitionSuite.scala index 8d403429ca5d2..384aadfb3a6f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableDropPartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableDropPartitionSuite.scala @@ -83,7 +83,7 @@ class AlterTableDropPartitionSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t DROP PARTITION (p1 = '')") }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map("details" -> "The spec ([p1=]) contains an empty partition column value") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRecoverPartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRecoverPartitionsSuite.scala index b219e21a3d881..54c0e7883ccda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRecoverPartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRecoverPartitionsSuite.scala @@ -123,7 +123,7 @@ trait AlterTableRecoverPartitionsSuiteBase extends command.AlterTableRecoverPart } checkError( exception = exception, - errorClass = "NOT_A_PARTITIONED_TABLE", + condition = "NOT_A_PARTITIONED_TABLE", parameters = Map( "operation" -> "ALTER TABLE RECOVER PARTITIONS", "tableIdentWithDB" -> "`spark_catalog`.`default`.`tbl`") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRenameColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRenameColumnSuite.scala index 86b34311bfb3d..a6b43ad4d5a78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRenameColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRenameColumnSuite.scala @@ -39,7 +39,7 @@ trait AlterTableRenameColumnSuiteBase extends command.AlterTableRenameColumnSuit exception = intercept[AnalysisException]( sql(s"ALTER TABLE $t RENAME COLUMN col1 TO col3") ), - errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", parameters = Map( "tableName" -> toSQLId(t), "operation" -> "RENAME COLUMN" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRenameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRenameSuite.scala index dfbdc6a4ca78e..f8708d5bff25a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRenameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRenameSuite.scala @@ -41,7 +41,7 @@ trait AlterTableRenameSuiteBase extends command.AlterTableRenameSuiteBase with Q exception = intercept[AnalysisException] { sql(s"ALTER TABLE $src RENAME TO dst_ns.dst_tbl") }, - errorClass = "_LEGACY_ERROR_TEMP_1073", + condition = "_LEGACY_ERROR_TEMP_1073", parameters = Map("db" -> "src_ns", "newDb" -> "dst_ns") ) } @@ -75,7 +75,7 @@ trait AlterTableRenameSuiteBase extends command.AlterTableRenameSuiteBase with Q exception = intercept[SparkRuntimeException] { sql(s"ALTER TABLE $src RENAME TO ns.dst_tbl") }, - errorClass = "LOCATION_ALREADY_EXISTS", + condition = "LOCATION_ALREADY_EXISTS", parameters = Map( 
"location" -> s"'$dst_dir'", "identifier" -> toSQLId(dst))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableSetLocationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableSetLocationSuite.scala index 53b9853f36c8c..8f5af2e1f2e76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableSetLocationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableSetLocationSuite.scala @@ -93,7 +93,7 @@ trait AlterTableSetLocationSuiteBase extends command.AlterTableSetLocationSuiteB exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t PARTITION (A='1', B='2') SET LOCATION '/path/to/part/ways3'") }, - errorClass = "_LEGACY_ERROR_TEMP_1231", + condition = "_LEGACY_ERROR_TEMP_1231", parameters = Map("key" -> "A", "tblName" -> "`spark_catalog`.`ns`.`tbl`") ) } @@ -127,7 +127,7 @@ trait AlterTableSetLocationSuiteBase extends command.AlterTableSetLocationSuiteB exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t PARTITION (b='2') SET LOCATION '/mister/spark'") }, - errorClass = "_LEGACY_ERROR_TEMP_1232", + condition = "_LEGACY_ERROR_TEMP_1232", parameters = Map( "specKeys" -> "b", "partitionColumnNames" -> "a, b", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableSetSerdeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableSetSerdeSuite.scala index 6e4d6a8a0c8f0..259c4cb52a0fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableSetSerdeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableSetSerdeSuite.scala @@ -86,14 +86,14 @@ class AlterTableSetSerdeSuite extends AlterTableSetSerdeSuiteBase with CommandSu exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t SET SERDE 'whatever'") }, - errorClass = "_LEGACY_ERROR_TEMP_1248", + condition = "_LEGACY_ERROR_TEMP_1248", parameters = Map.empty) checkError( exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t SET SERDE 'org.apache.madoop' " + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") }, - errorClass = "_LEGACY_ERROR_TEMP_1248", + condition = "_LEGACY_ERROR_TEMP_1248", parameters = Map.empty) // set serde properties only @@ -133,14 +133,14 @@ class AlterTableSetSerdeSuite extends AlterTableSetSerdeSuiteBase with CommandSu exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t PARTITION (a=1, b=2) SET SERDE 'whatever'") }, - errorClass = "_LEGACY_ERROR_TEMP_1247", + condition = "_LEGACY_ERROR_TEMP_1247", parameters = Map.empty) checkError( exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t PARTITION (a=1, b=2) SET SERDE 'org.apache.madoop' " + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") }, - errorClass = "_LEGACY_ERROR_TEMP_1247", + condition = "_LEGACY_ERROR_TEMP_1247", parameters = Map.empty) // set serde properties only @@ -149,7 +149,7 @@ class AlterTableSetSerdeSuite extends AlterTableSetSerdeSuiteBase with CommandSu sql(s"ALTER TABLE $t PARTITION (a=1, b=2) " + "SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')") }, - errorClass = "_LEGACY_ERROR_TEMP_1247", + condition = "_LEGACY_ERROR_TEMP_1247", parameters = Map.empty) // set things without explicitly specifying database @@ -158,7 +158,7 @@ class AlterTableSetSerdeSuite extends AlterTableSetSerdeSuiteBase with CommandSu exception = intercept[AnalysisException] { sql(s"ALTER TABLE tbl PARTITION (a=1, b=2) SET SERDEPROPERTIES ('kay' 
= 'veee')") }, - errorClass = "_LEGACY_ERROR_TEMP_1247", + condition = "_LEGACY_ERROR_TEMP_1247", parameters = Map.empty) // table to alter does not exist diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DescribeTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DescribeTableSuite.scala index 02cf1958b9499..eaf016ac2fa9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DescribeTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DescribeTableSuite.scala @@ -47,7 +47,7 @@ trait DescribeTableSuiteBase extends command.DescribeTableSuiteBase sql(s"DESCRIBE TABLE $tbl PARTITION (id = 1)") } checkError(e, - errorClass = "PARTITIONS_NOT_FOUND", + condition = "PARTITIONS_NOT_FOUND", parameters = Map("partitionList" -> "PARTITION (`id` = 1)", "tableName" -> "`ns`.`table`")) } @@ -63,7 +63,7 @@ trait DescribeTableSuiteBase extends command.DescribeTableSuiteBase exception = intercept[AnalysisException] { sql(s"DESC $tbl key1").collect() }, - errorClass = "COLUMN_NOT_FOUND", + condition = "COLUMN_NOT_FOUND", parameters = Map( "colName" -> "`key1`", "caseSensitiveConfig" -> "\"spark.sql.caseSensitive\"" @@ -89,7 +89,7 @@ trait DescribeTableSuiteBase extends command.DescribeTableSuiteBase exception = intercept[AnalysisException] { sql(s"DESC $tbl KEY").collect() }, - errorClass = "COLUMN_NOT_FOUND", + condition = "COLUMN_NOT_FOUND", parameters = Map( "colName" -> "`KEY`", "caseSensitiveConfig" -> "\"spark.sql.caseSensitive\"" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DropNamespaceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DropNamespaceSuite.scala index cec72b8855291..f3f9369ea062c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DropNamespaceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DropNamespaceSuite.scala @@ -39,7 +39,7 @@ trait DropNamespaceSuiteBase extends command.DropNamespaceSuiteBase exception = intercept[AnalysisException] { sql(s"DROP NAMESPACE default") }, - errorClass = "UNSUPPORTED_FEATURE.DROP_DATABASE", + condition = "UNSUPPORTED_FEATURE.DROP_DATABASE", parameters = Map("database" -> s"`$catalog`.`default`") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowColumnsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowColumnsSuite.scala index e9459a224486c..3e8ac98dbf767 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowColumnsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowColumnsSuite.scala @@ -38,7 +38,7 @@ trait ShowColumnsSuiteBase extends command.ShowColumnsSuiteBase { exception = intercept[AnalysisException] { sql("SHOW COLUMNS IN tbl FROM a.b.c") }, - errorClass = "REQUIRES_SINGLE_PART_NAMESPACE", + condition = "REQUIRES_SINGLE_PART_NAMESPACE", parameters = Map( "sessionCatalog" -> catalog, "namespace" -> "`a`.`b`.`c`" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowCreateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowCreateTableSuite.scala index 18b5da0ca59fa..afbb943bf91f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowCreateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowCreateTableSuite.scala @@ -162,7 +162,7 @@ 
trait ShowCreateTableSuiteBase extends command.ShowCreateTableSuiteBase exception = intercept[AnalysisException] { getShowCreateDDL(t, true) }, - errorClass = "UNSUPPORTED_SHOW_CREATE_TABLE.ON_DATA_SOURCE_TABLE_WITH_AS_SERDE", + condition = "UNSUPPORTED_SHOW_CREATE_TABLE.ON_DATA_SOURCE_TABLE_WITH_AS_SERDE", sqlState = "0A000", parameters = Map("tableName" -> "`spark_catalog`.`ns1`.`tbl`") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowNamespacesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowNamespacesSuite.scala index 85a46cfb93233..30189b46db4ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowNamespacesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowNamespacesSuite.scala @@ -49,7 +49,7 @@ trait ShowNamespacesSuiteBase extends command.ShowNamespacesSuiteBase { sql("SHOW NAMESPACES in dummy") } checkError(e, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> s"`$catalog`.`dummy`")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowPartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowPartitionsSuite.scala index 9863942c6ea19..0f64fa49f4862 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowPartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowPartitionsSuite.scala @@ -57,7 +57,7 @@ trait ShowPartitionsSuiteBase extends command.ShowPartitionsSuiteBase { exception = intercept[AnalysisException] { sql(s"SHOW PARTITIONS $view") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`spark_catalog`.`default`.`view1`", "operation" -> "SHOW PARTITIONS" @@ -80,7 +80,7 @@ trait ShowPartitionsSuiteBase extends command.ShowPartitionsSuiteBase { exception = intercept[AnalysisException] { sql(s"SHOW PARTITIONS $viewName") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> "`test_view`", "operation" -> "SHOW PARTITIONS" @@ -124,7 +124,7 @@ class ShowPartitionsSuite extends ShowPartitionsSuiteBase with CommandSuiteBase exception = intercept[AnalysisException] { sql(s"SHOW PARTITIONS $viewName") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> "`test_view`", "operation" -> "SHOW PARTITIONS" @@ -163,7 +163,7 @@ class ShowPartitionsSuite extends ShowPartitionsSuiteBase with CommandSuiteBase exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "INVALID_PARTITION_OPERATION.PARTITION_SCHEMA_IS_EMPTY", + condition = "INVALID_PARTITION_OPERATION.PARTITION_SCHEMA_IS_EMPTY", parameters = Map("name" -> tableName)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowTablesSuite.scala index 9be802b5f1fea..001267a37d382 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowTablesSuite.scala @@ -56,7 +56,7 @@ trait ShowTablesSuiteBase extends command.ShowTablesSuiteBase with command.Tests exception = intercept[AnalysisException] { 
runShowTablesSql("SHOW TABLES FROM a.b", Seq()) }, - errorClass = "_LEGACY_ERROR_TEMP_1126", + condition = "_LEGACY_ERROR_TEMP_1126", parameters = Map("catalog" -> "a.b") ) } @@ -102,7 +102,7 @@ trait ShowTablesSuiteBase extends command.ShowTablesSuiteBase with command.Tests exception = intercept[AnalysisException] { sql(showTableCmd) }, - errorClass = "_LEGACY_ERROR_TEMP_1125", + condition = "_LEGACY_ERROR_TEMP_1125", parameters = Map.empty ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala index 5810a35ddcf8b..348b216aeb044 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala @@ -201,7 +201,7 @@ class TruncateTableSuite extends TruncateTableSuiteBase with CommandSuiteBase { exception = intercept[AnalysisException] { sql(s"TRUNCATE TABLE $t") }, - errorClass = "_LEGACY_ERROR_TEMP_1266", + condition = "_LEGACY_ERROR_TEMP_1266", parameters = Map("tableIdentWithDB" -> "`spark_catalog`.`ns`.`tbl`") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala index defa026c0e281..e3b6a9b5e6107 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala @@ -43,7 +43,7 @@ class AlterTableAddPartitionSuite exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "INVALID_PARTITION_OPERATION.PARTITION_MANAGEMENT_IS_UNSUPPORTED", + condition = "INVALID_PARTITION_OPERATION.PARTITION_MANAGEMENT_IS_UNSUPPORTED", parameters = Map("name" -> tableName), context = ExpectedContext( fragment = t, @@ -126,7 +126,7 @@ class AlterTableAddPartitionSuite " PARTITION (id=2) LOCATION 'loc1'") } checkError(e, - errorClass = "PARTITIONS_ALREADY_EXIST", + condition = "PARTITIONS_ALREADY_EXIST", parameters = Map("partitionList" -> "PARTITION (`id` = 2)", "tableName" -> "`test_catalog`.`ns`.`tbl`")) @@ -146,9 +146,8 @@ class AlterTableAddPartitionSuite exception = intercept[SparkNumberFormatException] { sql(s"ALTER TABLE $t ADD PARTITION (p='aaa')") }, - errorClass = "CAST_INVALID_INPUT", + condition = "CAST_INVALID_INPUT", parameters = Map( - "ansiConfig" -> "\"spark.sql.ansi.enabled\"", "expression" -> "'aaa'", "sourceType" -> "\"STRING\"", "targetType" -> "\"INT\""), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableDropColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableDropColumnSuite.scala index d541f1286e598..6cd9ed2628dbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableDropColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableDropColumnSuite.scala @@ -45,7 +45,7 @@ class AlterTableDropColumnSuite exception = intercept[AnalysisException] { sql("ALTER TABLE does_not_exist DROP COLUMN id") }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`does_not_exist`"), context = ExpectedContext(fragment = "does_not_exist", start = 12, stop = 25) ) @@ -127,7 +127,7 @@ 
class AlterTableDropColumnSuite exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`does_not_exist`", @@ -145,7 +145,7 @@ class AlterTableDropColumnSuite exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`point`.`does_not_exist`", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableDropPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableDropPartitionSuite.scala index 2df7eebaecc81..35afb00ff0f38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableDropPartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableDropPartitionSuite.scala @@ -43,7 +43,7 @@ class AlterTableDropPartitionSuite exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "INVALID_PARTITION_OPERATION.PARTITION_MANAGEMENT_IS_UNSUPPORTED", + condition = "INVALID_PARTITION_OPERATION.PARTITION_MANAGEMENT_IS_UNSUPPORTED", parameters = Map("name" -> tableName), context = ExpectedContext( fragment = t, @@ -61,7 +61,7 @@ class AlterTableDropPartitionSuite exception = intercept[SparkUnsupportedOperationException] { sql(s"ALTER TABLE $t DROP PARTITION (id=1) PURGE") }, - errorClass = "UNSUPPORTED_FEATURE.PURGE_PARTITION", + condition = "UNSUPPORTED_FEATURE.PURGE_PARTITION", parameters = Map.empty ) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableRecoverPartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableRecoverPartitionsSuite.scala index ff6ff0df5306a..508b8e9d0339d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableRecoverPartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableRecoverPartitionsSuite.scala @@ -35,7 +35,7 @@ class AlterTableRecoverPartitionsSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t RECOVER PARTITIONS") }, - errorClass = "NOT_SUPPORTED_COMMAND_FOR_V2_TABLE", + condition = "NOT_SUPPORTED_COMMAND_FOR_V2_TABLE", parameters = Map("cmd" -> "ALTER TABLE ... 
RECOVER PARTITIONS") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableRenameColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableRenameColumnSuite.scala index a2ab63d9ebd85..6edf9ee4a10d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableRenameColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableRenameColumnSuite.scala @@ -45,7 +45,7 @@ class AlterTableRenameColumnSuite exception = intercept[AnalysisException] { sql("ALTER TABLE does_not_exist RENAME COLUMN col1 TO col3") }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`does_not_exist`"), context = ExpectedContext(fragment = "does_not_exist", start = 12, stop = 25) ) @@ -153,7 +153,7 @@ class AlterTableRenameColumnSuite exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`does_not_exist`", @@ -171,7 +171,7 @@ class AlterTableRenameColumnSuite exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`point`.`does_not_exist`", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableReplaceColumnsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableReplaceColumnsSuite.scala index 599820d7622d4..4afe294549f45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableReplaceColumnsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableReplaceColumnsSuite.scala @@ -35,7 +35,7 @@ class AlterTableReplaceColumnsSuite exception = intercept[ParseException] { sql(sql1) }, - errorClass = "UNSUPPORTED_DEFAULT_VALUE.WITHOUT_SUGGESTION", + condition = "UNSUPPORTED_DEFAULT_VALUE.WITHOUT_SUGGESTION", parameters = Map.empty, context = ExpectedContext(sql1, 0, 48) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableSetLocationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableSetLocationSuite.scala index 13f6b8d5b33bb..feb00ce0ec69f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableSetLocationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableSetLocationSuite.scala @@ -60,7 +60,7 @@ class AlterTableSetLocationSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $t PARTITION(ds='2017-06-10') SET LOCATION 's3://bucket/path'") }, - errorClass = "_LEGACY_ERROR_TEMP_1045", + condition = "_LEGACY_ERROR_TEMP_1045", parameters = Map.empty ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableSetSerdeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableSetSerdeSuite.scala index d17bab99d01fe..971a5cd077bc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableSetSerdeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableSetSerdeSuite.scala @@ -40,7 +40,7 @@ class AlterTableSetSerdeSuite extends command.AlterTableSetSerdeSuiteBase with C exception = 
intercept[AnalysisException] { sql(s"ALTER TABLE $t SET SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',')") }, - errorClass = "NOT_SUPPORTED_COMMAND_FOR_V2_TABLE", + condition = "NOT_SUPPORTED_COMMAND_FOR_V2_TABLE", sqlState = "0A000", parameters = Map("cmd" -> "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeTableSuite.scala index cfd26c09bf3e5..9cd7f0d8aade6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeTableSuite.scala @@ -106,7 +106,7 @@ class DescribeTableSuite extends command.DescribeTableSuiteBase exception = intercept[AnalysisException] { sql(query).collect() }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`key1`", @@ -137,7 +137,7 @@ class DescribeTableSuite extends command.DescribeTableSuiteBase exception = intercept[AnalysisException] { sql(query).collect() }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`KEY`", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropTableSuite.scala index 83bded7ab4f52..ffc2c6c679a8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DropTableSuite.scala @@ -35,7 +35,7 @@ class DropTableSuite extends command.DropTableSuiteBase with CommandSuiteBase { exception = intercept[SparkUnsupportedOperationException] { sql(s"DROP TABLE $catalog.ns.tbl PURGE") }, - errorClass = "UNSUPPORTED_FEATURE.PURGE_TABLE", + condition = "UNSUPPORTED_FEATURE.PURGE_TABLE", parameters = Map.empty ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/MsckRepairTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/MsckRepairTableSuite.scala index 381e55b49393c..73764e88bffa9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/MsckRepairTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/MsckRepairTableSuite.scala @@ -36,7 +36,7 @@ class MsckRepairTableSuite exception = intercept[AnalysisException] { sql(s"MSCK REPAIR TABLE $t") }, - errorClass = "NOT_SUPPORTED_COMMAND_FOR_V2_TABLE", + condition = "NOT_SUPPORTED_COMMAND_FOR_V2_TABLE", parameters = Map("cmd" -> "MSCK REPAIR TABLE") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala index 8e1bb08162e3e..44a1bcad46a03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala @@ -44,7 +44,7 @@ class ShowNamespacesSuite extends command.ShowNamespacesSuiteBase with CommandSu exception = intercept[AnalysisException] { sql("SHOW NAMESPACES") }, - errorClass = "_LEGACY_ERROR_TEMP_1184", + condition = "_LEGACY_ERROR_TEMP_1184", parameters = 
Map( "plugin" -> "testcat_no_namespace", "ability" -> "namespaces" @@ -58,7 +58,7 @@ class ShowNamespacesSuite extends command.ShowNamespacesSuiteBase with CommandSu exception = intercept[AnalysisException] { sql("SHOW NAMESPACES in testcat_no_namespace") }, - errorClass = "_LEGACY_ERROR_TEMP_1184", + condition = "_LEGACY_ERROR_TEMP_1184", parameters = Map( "plugin" -> "testcat_no_namespace", "ability" -> "namespaces" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowPartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowPartitionsSuite.scala index 203ef4314ad25..1fb1c48890607 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowPartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowPartitionsSuite.scala @@ -38,7 +38,7 @@ class ShowPartitionsSuite extends command.ShowPartitionsSuiteBase with CommandSu exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "INVALID_PARTITION_OPERATION.PARTITION_SCHEMA_IS_EMPTY", + condition = "INVALID_PARTITION_OPERATION.PARTITION_SCHEMA_IS_EMPTY", parameters = Map("name" -> tableName), context = ExpectedContext( fragment = t, @@ -61,7 +61,7 @@ class ShowPartitionsSuite extends command.ShowPartitionsSuiteBase with CommandSu exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "INVALID_PARTITION_OPERATION.PARTITION_MANAGEMENT_IS_UNSUPPORTED", + condition = "INVALID_PARTITION_OPERATION.PARTITION_MANAGEMENT_IS_UNSUPPORTED", parameters = Map("name" -> tableName), context = ExpectedContext( fragment = table, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/TruncateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/TruncateTableSuite.scala index 36b994c21a083..972511a470465 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/TruncateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/TruncateTableSuite.scala @@ -41,7 +41,7 @@ class TruncateTableSuite extends command.TruncateTableSuiteBase with CommandSuit exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "INVALID_PARTITION_OPERATION.PARTITION_MANAGEMENT_IS_UNSUPPORTED", + condition = "INVALID_PARTITION_OPERATION.PARTITION_MANAGEMENT_IS_UNSUPPORTED", parameters = Map("name" -> tableName), context = ExpectedContext( fragment = t, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala index 7f886940473de..fd9d31e7a594d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala @@ -128,7 +128,7 @@ class DataSourceSuite extends SharedSparkSession with PrivateMethodTester { enableGlobbing = true ) ), - errorClass = "PATH_NOT_FOUND", + condition = "PATH_NOT_FOUND", parameters = Map("path" -> nonExistentPath.toString) ) } @@ -173,7 +173,7 @@ class DataSourceSuite extends SharedSparkSession with PrivateMethodTester { new File(uuid, "file3").getAbsolutePath, uuid).rdd }, - errorClass = "PATH_NOT_FOUND", + condition = "PATH_NOT_FOUND", parameters = Map("path" -> "file:.*"), matchPVals = true ) @@ -187,7 +187,7 @@ class DataSourceSuite extends SharedSparkSession with PrivateMethodTester { exception = 
intercept[AnalysisException] { spark.read.format("text").load(s"$nonExistentBasePath/*") }, - errorClass = "PATH_NOT_FOUND", + condition = "PATH_NOT_FOUND", parameters = Map("path" -> s"file:$nonExistentBasePath/*") ) @@ -200,7 +200,7 @@ class DataSourceSuite extends SharedSparkSession with PrivateMethodTester { exception = intercept[AnalysisException] { spark.read.json(s"${baseDir.getAbsolutePath}/*/*-xyz.json").rdd }, - errorClass = "PATH_NOT_FOUND", + condition = "PATH_NOT_FOUND", parameters = Map("path" -> s"file:${baseDir.getAbsolutePath}/*/*-xyz.json") ) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 21623f94c8baf..31b7380889158 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -137,7 +137,7 @@ class FileIndexSuite extends SharedSparkSession { exception = intercept[SparkRuntimeException] { fileIndex.partitionSpec() }, - errorClass = "_LEGACY_ERROR_TEMP_2058", + condition = "_LEGACY_ERROR_TEMP_2058", parameters = Map("value" -> "foo", "dataType" -> "IntegerType", "columnName" -> "a") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala index 185a7cf5a6b40..880f1dd9af8f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala @@ -243,7 +243,7 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select("name", METADATA_FILE_NAME).collect() }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`"), context = ExpectedContext(fragment = "select", callSitePattern = getCurrentClassCallSitePattern)) @@ -309,7 +309,7 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.metadataColumn("foo") }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`foo`", "proposal" -> "`_metadata`")) // Name exists, but does not reference a metadata column @@ -317,7 +317,7 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.metadataColumn("name") }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`name`", "proposal" -> "`_metadata`")) } @@ -525,7 +525,7 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select("name", "_metadata.file_name").collect() }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`"), context = ExpectedContext( fragment = "select", @@ -535,7 +535,7 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select("name", "_METADATA.file_NAME").collect() }, - errorClass = 
"FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map("fieldName" -> "`file_NAME`", "fields" -> "`id`, `university`"), context = ExpectedContext( fragment = "select", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala index fefb16a351fdb..c798196c4f0ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala @@ -101,7 +101,7 @@ class OrcReadSchemaSuite override def beforeAll(): Unit = { super.beforeAll() - originalConf = spark.conf.get(SQLConf.ORC_VECTORIZED_READER_ENABLED) + originalConf = sqlConf.getConf(SQLConf.ORC_VECTORIZED_READER_ENABLED) spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "false") } @@ -126,7 +126,7 @@ class VectorizedOrcReadSchemaSuite override def beforeAll(): Unit = { super.beforeAll() - originalConf = spark.conf.get(SQLConf.ORC_VECTORIZED_READER_ENABLED) + originalConf = sqlConf.getConf(SQLConf.ORC_VECTORIZED_READER_ENABLED) spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "true") } @@ -169,7 +169,7 @@ class ParquetReadSchemaSuite override def beforeAll(): Unit = { super.beforeAll() - originalConf = spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + originalConf = sqlConf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "false") } @@ -193,7 +193,7 @@ class VectorizedParquetReadSchemaSuite override def beforeAll(): Unit = { super.beforeAll() - originalConf = spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + originalConf = sqlConf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") } @@ -217,7 +217,7 @@ class MergedParquetReadSchemaSuite override def beforeAll(): Unit = { super.beforeAll() - originalConf = spark.conf.get(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED) + originalConf = sqlConf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED) spark.conf.set(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key, "true") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala index 3762241719acd..26962d89452ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala @@ -101,7 +101,7 @@ class SaveIntoDataSourceCommandSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { dataSource.planForWriting(SaveMode.ErrorIfExists, df.logicalPlan) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map("columnName" -> "`col`", "columnType" -> s"\"${testCase._2}\"", "format" -> ".*JdbcRelationProvider.*"), matchPVals = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala index 5c118ac12b72a..deb62eb3ac234 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala @@ -168,7 +168,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSparkSession { .format(BINARY_FILE) .save(s"$tmpDir/test_save") }, - errorClass = "_LEGACY_ERROR_TEMP_2075", + condition = "_LEGACY_ERROR_TEMP_2075", parameters = Map.empty) } } @@ -346,7 +346,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSparkSession { } test("fail fast and do not attempt to read if a file is too big") { - assert(spark.conf.get(SOURCES_BINARY_FILE_MAX_LENGTH) === Int.MaxValue) + assert(sqlConf.getConf(SOURCES_BINARY_FILE_MAX_LENGTH) === Int.MaxValue) withTempPath { file => val path = file.getPath val content = "123".getBytes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index b1b5c882f4e97..e2d1d9b05c3c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -266,7 +266,7 @@ abstract class CSVSuite spark.read.format("csv").option("charset", "1-9588-osi") .load(testFile(carsFile8859)) }, - errorClass = "INVALID_PARAMETER_VALUE.CHARSET", + condition = "INVALID_PARAMETER_VALUE.CHARSET", parameters = Map( "charset" -> "1-9588-osi", "functionName" -> toSQLId("CSVOptions"), @@ -388,13 +388,13 @@ abstract class CSVSuite } checkErrorMatchPVals( exception = e1, - errorClass = "FAILED_READ_FILE.NO_HINT", + condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*$carsFile.*")) val e2 = e1.getCause.asInstanceOf[SparkException] assert(e2.getErrorClass == "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION") checkError( exception = e2.getCause.asInstanceOf[SparkRuntimeException], - errorClass = "MALFORMED_CSV_RECORD", + condition = "MALFORMED_CSV_RECORD", parameters = Map("badRecord" -> "2015,Chevy,Volt") ) } @@ -650,7 +650,7 @@ abstract class CSVSuite .csv(csvDir) } }, - errorClass = "INVALID_PARAMETER_VALUE.CHARSET", + condition = "INVALID_PARAMETER_VALUE.CHARSET", parameters = Map( "charset" -> "1-9588-osi", "functionName" -> toSQLId("CSVOptions"), @@ -1269,7 +1269,7 @@ abstract class CSVSuite } checkErrorMatchPVals( exception = ex, - errorClass = "TASK_WRITE_FAILED", + condition = "TASK_WRITE_FAILED", parameters = Map("path" -> s".*${path.getName}.*")) val msg = ex.getCause.getMessage assert( @@ -1509,7 +1509,7 @@ abstract class CSVSuite .csv(testFile(valueMalformedFile)) .collect() }, - errorClass = "_LEGACY_ERROR_TEMP_1097", + condition = "_LEGACY_ERROR_TEMP_1097", parameters = Map.empty ) } @@ -1523,7 +1523,7 @@ abstract class CSVSuite } checkErrorMatchPVals( exception = e, - errorClass = "FAILED_READ_FILE.NO_HINT", + condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*${inputFile.getName}.*") ) assert(e.getCause.isInstanceOf[EOFException]) @@ -1533,7 +1533,7 @@ abstract class CSVSuite } checkErrorMatchPVals( exception = e2, - errorClass = "FAILED_READ_FILE.NO_HINT", + condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*${inputFile.getName}.*") ) assert(e2.getCause.getCause.getCause.isInstanceOf[EOFException]) @@ -1557,7 +1557,7 @@ abstract class CSVSuite exception = intercept[SparkException] { df.collect() }, - errorClass = "FAILED_READ_FILE.FILE_NOT_EXIST", + condition = "FAILED_READ_FILE.FILE_NOT_EXIST", parameters = Map("path" -> s".*$dir.*") ) } @@ 
-1705,7 +1705,7 @@ abstract class CSVSuite } checkError( exception = exception, - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map("badRecord" -> "[null]", "failFastMode" -> "FAILFAST")) assert(exception.getCause.getMessage.contains("""input string: "10u12"""")) @@ -1794,7 +1794,7 @@ abstract class CSVSuite spark.read.schema(schema).csv(testFile(valueMalformedFile)) .select("_corrupt_record").collect() }, - errorClass = "UNSUPPORTED_FEATURE.QUERY_ONLY_CORRUPT_RECORD_COLUMN", + condition = "UNSUPPORTED_FEATURE.QUERY_ONLY_CORRUPT_RECORD_COLUMN", parameters = Map.empty ) // workaround @@ -2013,7 +2013,7 @@ abstract class CSVSuite } checkErrorMatchPVals( exception = exception, - errorClass = "FAILED_READ_FILE.NO_HINT", + condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*${path.getCanonicalPath}.*")) assert(exception.getCause.getMessage.contains("CSV header does not conform to the schema")) @@ -2029,7 +2029,7 @@ abstract class CSVSuite } checkErrorMatchPVals( exception = exceptionForShortSchema, - errorClass = "FAILED_READ_FILE.NO_HINT", + condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*${path.getCanonicalPath}.*")) assert(exceptionForShortSchema.getCause.getMessage.contains( "Number of column in CSV header is not equal to number of fields in the schema")) @@ -2050,7 +2050,7 @@ abstract class CSVSuite } checkErrorMatchPVals( exception = exceptionForLongSchema, - errorClass = "FAILED_READ_FILE.NO_HINT", + condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*${path.getCanonicalPath}.*")) assert(exceptionForLongSchema.getCause.getMessage.contains( "Header length: 2, schema size: 3")) @@ -2067,7 +2067,7 @@ abstract class CSVSuite } checkErrorMatchPVals( exception = caseSensitiveException, - errorClass = "FAILED_READ_FILE.NO_HINT", + condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*${path.getCanonicalPath}.*")) assert(caseSensitiveException.getCause.getMessage.contains( "CSV header does not conform to the schema")) @@ -2122,7 +2122,7 @@ abstract class CSVSuite exception = intercept[SparkIllegalArgumentException] { spark.read.schema(ischema).option("header", true).option("enforceSchema", false).csv(ds) }, - errorClass = "_LEGACY_ERROR_TEMP_3241", + condition = "_LEGACY_ERROR_TEMP_3241", parameters = Map("msg" -> """CSV header does not conform to the schema. | Header: columnA, columnB @@ -2161,7 +2161,7 @@ abstract class CSVSuite .schema(schema) .csv(Seq("col1,col2", "1.0,a").toDS()) }, - errorClass = "_LEGACY_ERROR_TEMP_3241", + condition = "_LEGACY_ERROR_TEMP_3241", parameters = Map("msg" -> """CSV header does not conform to the schema. 
| Header: col1, col2 @@ -2790,7 +2790,7 @@ abstract class CSVSuite exception = intercept[SparkUpgradeException] { csv.collect() }, - errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.PARSE_DATETIME_BY_NEW_PARSER", + condition = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.PARSE_DATETIME_BY_NEW_PARSER", parameters = Map( "datetime" -> "'2020-01-27T20:06:11.847-08000'", "config" -> "\"spark.sql.legacy.timeParserPolicy\"")) @@ -2850,7 +2850,7 @@ abstract class CSVSuite exception = intercept[AnalysisException] { readback.filter($"AAA" === 2 && $"bbb" === 3).collect() }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`"), context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) @@ -3458,11 +3458,11 @@ class CSVv1Suite extends CSVSuite { } checkErrorMatchPVals( exception = ex, - errorClass = "FAILED_READ_FILE.NO_HINT", + condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*$carsFile")) checkError( exception = ex.getCause.asInstanceOf[SparkException], - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map( "badRecord" -> "[2015,Chevy,Volt,null,null]", "failFastMode" -> "FAILFAST") @@ -3487,7 +3487,7 @@ class CSVv2Suite extends CSVSuite { .options(Map("header" -> "true", "mode" -> "failfast")) .load(testFile(carsFile)).collect() }, - errorClass = "FAILED_READ_FILE.NO_HINT", + condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*$carsFile"), matchPVals = true ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala index 486532028de9c..d3723881bfa24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala @@ -51,7 +51,7 @@ class JdbcUtilsSuite extends SparkFunSuite { } checkError( exception = duplicate, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`c1`")) // Throw ParseException @@ -59,14 +59,14 @@ class JdbcUtilsSuite extends SparkFunSuite { exception = intercept[ParseException]{ JdbcUtils.getCustomSchema(tableSchema, "c3 DATEE, C2 STRING", caseInsensitive) }, - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map("typeName" -> "\"DATEE\"")) checkError( exception = intercept[ParseException]{ JdbcUtils.getCustomSchema(tableSchema, "c3 DATE. 
C2 STRING", caseInsensitive) }, - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'.'", "hint" -> "")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index c31ecbc437495..aea95f0af117a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -46,6 +46,7 @@ import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, Da import org.apache.spark.sql.execution.datasources.v2.json.JsonScanBuilder import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData} import org.apache.spark.sql.types._ import org.apache.spark.sql.types.StructType.fromDDL import org.apache.spark.sql.types.TestUDT.{MyDenseVector, MyDenseVectorUDT} @@ -1071,7 +1072,7 @@ abstract class JsonSuite .option("mode", "FAILFAST") .json(corruptRecords) }, - errorClass = "_LEGACY_ERROR_TEMP_2165", + condition = "_LEGACY_ERROR_TEMP_2165", parameters = Map("failFastMode" -> "FAILFAST") ) @@ -1083,7 +1084,7 @@ abstract class JsonSuite .json(corruptRecords) .collect() }, - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map( "badRecord" -> "[null]", "failFastMode" -> "FAILFAST") @@ -1961,7 +1962,7 @@ abstract class JsonSuite } checkErrorMatchPVals( exception = e, - errorClass = "FAILED_READ_FILE.NO_HINT", + condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*${inputFile.getName}.*")) assert(e.getCause.isInstanceOf[EOFException]) assert(e.getCause.getMessage === "Unexpected end of input stream") @@ -1989,7 +1990,7 @@ abstract class JsonSuite exception = intercept[SparkException] { df.collect() }, - errorClass = "FAILED_READ_FILE.FILE_NOT_EXIST", + condition = "FAILED_READ_FILE.FILE_NOT_EXIST", parameters = Map("path" -> s".*$dir.*") ) } @@ -2075,7 +2076,7 @@ abstract class JsonSuite .option("mode", "FAILFAST") .json(path) }, - errorClass = "_LEGACY_ERROR_TEMP_2167", + condition = "_LEGACY_ERROR_TEMP_2167", parameters = Map("failFastMode" -> "FAILFAST", "dataType" -> "string|bigint")) val ex = intercept[SparkException] { @@ -2088,11 +2089,11 @@ abstract class JsonSuite } checkErrorMatchPVals( exception = ex, - errorClass = "FAILED_READ_FILE.NO_HINT", + condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*$path.*")) checkError( exception = ex.getCause.asInstanceOf[SparkException], - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map( "badRecord" -> "[null]", "failFastMode" -> "FAILFAST") @@ -2116,7 +2117,7 @@ abstract class JsonSuite .schema(schema) .json(corruptRecords) }, - errorClass = "_LEGACY_ERROR_TEMP_1097", + condition = "_LEGACY_ERROR_TEMP_1097", parameters = Map.empty ) @@ -2133,7 +2134,7 @@ abstract class JsonSuite .json(path) .collect() }, - errorClass = "_LEGACY_ERROR_TEMP_1097", + condition = "_LEGACY_ERROR_TEMP_1097", parameters = Map.empty ) } @@ -2181,7 +2182,7 @@ abstract class JsonSuite .json(Seq(lowerCasedJsonFieldValue._1).toDS()) .collect() }, - errorClass = "MALFORMED_RECORD_IN_PARSING.CANNOT_PARSE_STRING_AS_DATATYPE", + condition = 
"MALFORMED_RECORD_IN_PARSING.CANNOT_PARSE_STRING_AS_DATATYPE", parameters = Map( "failFastMode" -> "FAILFAST", "badRecord" -> lowerCasedJsonFieldValue._1, @@ -2209,7 +2210,7 @@ abstract class JsonSuite exception = intercept[AnalysisException] { spark.read.schema(schema).json(path).select("_corrupt_record").collect() }, - errorClass = "UNSUPPORTED_FEATURE.QUERY_ONLY_CORRUPT_RECORD_COLUMN", + condition = "UNSUPPORTED_FEATURE.QUERY_ONLY_CORRUPT_RECORD_COLUMN", parameters = Map.empty ) @@ -2377,7 +2378,7 @@ abstract class JsonSuite .json(testFile("test-data/utf16LE.json")) .count() }, - errorClass = "INVALID_PARAMETER_VALUE.CHARSET", + condition = "INVALID_PARAMETER_VALUE.CHARSET", parameters = Map( "charset" -> invalidCharset, "functionName" -> toSQLId("JSONOptionsInRead"), @@ -2411,11 +2412,11 @@ abstract class JsonSuite } checkErrorMatchPVals( exception = exception, - errorClass = "FAILED_READ_FILE.NO_HINT", + condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*$fileName.*")) checkError( exception = exception.getCause.asInstanceOf[SparkException], - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map("badRecord" -> "[empty row]", "failFastMode" -> "FAILFAST") ) } @@ -2473,7 +2474,7 @@ abstract class JsonSuite .json(path.getCanonicalPath) } }, - errorClass = "INVALID_PARAMETER_VALUE.CHARSET", + condition = "INVALID_PARAMETER_VALUE.CHARSET", parameters = Map( "charset" -> encoding, "functionName" -> toSQLId("JSONOptions"), @@ -2755,11 +2756,11 @@ abstract class JsonSuite val e = intercept[SparkException] { df.collect() } checkError( exception = e, - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map("badRecord" -> "[null]", "failFastMode" -> "FAILFAST")) checkError( exception = e.getCause.asInstanceOf[SparkRuntimeException], - errorClass = "EMPTY_JSON_FIELD_VALUE", + condition = "EMPTY_JSON_FIELD_VALUE", parameters = Map("dataType" -> toSQLType(dataType)) ) } @@ -2900,7 +2901,7 @@ abstract class JsonSuite exception = intercept[SparkUpgradeException] { json.collect() }, - errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.PARSE_DATETIME_BY_NEW_PARSER", + condition = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.PARSE_DATETIME_BY_NEW_PARSER", parameters = Map( "datetime" -> "'2020-01-27T20:06:11.847-08000'", "config" -> "\"spark.sql.legacy.timeParserPolicy\"")) @@ -3089,7 +3090,7 @@ abstract class JsonSuite } checkErrorMatchPVals( exception = err, - errorClass = "TASK_WRITE_FAILED", + condition = "TASK_WRITE_FAILED", parameters = Map("path" -> s".*${path.getName}.*")) val msg = err.getCause.getMessage @@ -3196,7 +3197,7 @@ abstract class JsonSuite exception = intercept[AnalysisException] { spark.read.json(path.getCanonicalPath).collect() }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`aaa`")) } withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { @@ -3207,7 +3208,7 @@ abstract class JsonSuite exception = intercept[AnalysisException] { readback.filter($"AAA" === 0 && $"bbb" === 1).collect() }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`"), context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) @@ -3361,13 +3362,13 @@ abstract class JsonSuite checkError( exception = 
exception, - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map("badRecord" -> "[null]", "failFastMode" -> "FAILFAST") ) checkError( exception = ExceptionUtils.getRootCause(exception).asInstanceOf[SparkRuntimeException], - errorClass = "INVALID_JSON_ROOT_FIELD", + condition = "INVALID_JSON_ROOT_FIELD", parameters = Map.empty ) @@ -3851,7 +3852,7 @@ abstract class JsonSuite exception = intercept[AnalysisException]( spark.read.schema(jsonDataSchema).json(Seq(jsonData).toDS()).collect() ), - errorClass = "INVALID_JSON_SCHEMA_MAP_TYPE", + condition = "INVALID_JSON_SCHEMA_MAP_TYPE", parameters = Map("jsonSchema" -> toSQLType(jsonDataSchema))) val jsonDir = new File(dir, "json").getCanonicalPath @@ -3861,7 +3862,7 @@ abstract class JsonSuite exception = intercept[AnalysisException]( spark.read.schema(jsonDirSchema).json(jsonDir).collect() ), - errorClass = "INVALID_JSON_SCHEMA_MAP_TYPE", + condition = "INVALID_JSON_SCHEMA_MAP_TYPE", parameters = Map("jsonSchema" -> toSQLType(jsonDirSchema))) } } @@ -3968,6 +3969,34 @@ abstract class JsonSuite ) } } + + test("SPARK-48965: Dataset#toJSON should use correct schema #1: decimals") { + val numString = "123.456" + val bd = BigDecimal(numString) + val ds1 = sql(s"select ${numString}bd as a, ${numString}bd as b").as[DecimalData] + checkDataset( + ds1, + DecimalData(bd, bd) + ) + val ds2 = ds1.toJSON + checkDataset( + ds2, + "{\"a\":123.456000000000000000,\"b\":123.456000000000000000}" + ) + } + + test("SPARK-48965: Dataset#toJSON should use correct schema #2: misaligned columns") { + val ds1 = sql("select 'Hey there' as value, 90000001 as key").as[TestData] + checkDataset( + ds1, + TestData(90000001, "Hey there") + ) + val ds2 = ds1.toJSON + checkDataset( + ds2, + "{\"key\":90000001,\"value\":\"Hey there\"}" + ) + } } class JsonV1Suite extends JsonSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index f1431f2a81b8e..f13d66b76838f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -683,7 +683,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"select a from $tableName where a < 0").collect() }, - errorClass = "AMBIGUOUS_REFERENCE", + condition = "AMBIGUOUS_REFERENCE", parameters = Map( "name" -> "`a`", "referenceNames" -> ("[`spark_catalog`.`default`.`spark_32622`.`a`, " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index 536db3dfe74b5..2e6413d998d12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -594,7 +594,7 @@ abstract class OrcQueryTest extends OrcTest { exception = intercept[AnalysisException] { testAllCorruptFiles() }, - errorClass = "UNABLE_TO_INFER_SCHEMA", + condition = "UNABLE_TO_INFER_SCHEMA", parameters = Map("format" -> "ORC") ) testAllCorruptFilesWithoutSchemaInfer() @@ -619,7 +619,7 @@ abstract class OrcQueryTest extends OrcTest { exception = intercept[SparkException] { 
testAllCorruptFiles() }, - errorClass = "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER", + condition = "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER", parameters = Map("path" -> "file:.*") ) val e4 = intercept[SparkException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 5d247c76b70be..9348d10711b35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -235,7 +235,7 @@ abstract class OrcSuite exception = intercept[SparkException] { testMergeSchemasInParallel(false, schemaReader) }.getCause.getCause.asInstanceOf[SparkException], - errorClass = "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER", + condition = "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER", parameters = Map("path" -> "file:.*") ) } @@ -481,7 +481,7 @@ abstract class OrcSuite exception = intercept[SparkException] { spark.read.orc(basePath).columns.length }.getCause.getCause.asInstanceOf[SparkException], - errorClass = "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER", + condition = "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER", parameters = Map("path" -> "file:.*") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index 48b4f8d4bc015..b8669ee4d1ef1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -63,7 +63,7 @@ trait OrcTest extends QueryTest with FileBasedDataSourceTest with BeforeAndAfter protected override def beforeAll(): Unit = { super.beforeAll() - originalConfORCImplementation = spark.conf.get(ORC_IMPLEMENTATION) + originalConfORCImplementation = spark.sessionState.conf.getConf(ORC_IMPLEMENTATION) spark.conf.set(ORC_IMPLEMENTATION.key, orcImp) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala index 28644720d0436..359436ca23636 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala @@ -125,7 +125,7 @@ class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSpar exception = intercept[SparkIllegalArgumentException] { checkCompressionCodec("aa", isPartitioned) }, - errorClass = "CODEC_NOT_AVAILABLE.WITH_AVAILABLE_CODECS_SUGGESTION", + condition = "CODEC_NOT_AVAILABLE.WITH_AVAILABLE_CODECS_SUGGESTION", parameters = Map( "codecName" -> "aa", "availableCodecs" -> ("brotli, uncompressed, lzo, snappy, " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala index b0995477030c9..ee283386b8eff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala @@ -65,7 +65,7 @@ abstract class ParquetFileFormatSuite exception = intercept[SparkException] { testReadFooters(false) }.getCause.asInstanceOf[SparkException], - errorClass = "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER", + condition = "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER", parameters = Map("path" -> "file:.*") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala index 5e59418f8f928..be8d41c75bfde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala @@ -130,7 +130,7 @@ class ParquetFileMetadataStructRowIndexSuite extends QueryTest with SharedSparkS exception = intercept[AnalysisException] { df.select("*", s"${FileFormat.METADATA_NAME}.${ROW_INDEX}") }, - errorClass = "FIELD_NOT_FOUND", + condition = "FIELD_NOT_FOUND", parameters = Map( "fieldName" -> "`row_index`", "fields" -> ("`file_path`, `file_name`, `file_size`, " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 02e1c70cc8cb7..0afa545595c77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -846,7 +846,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession def checkCompressionCodec(codec: ParquetCompressionCodec): Unit = { withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { withParquetFile(data) { path => - assertResult(spark.conf.get(SQLConf.PARQUET_COMPRESSION).toUpperCase(Locale.ROOT)) { + assertResult(spark.conf.get(SQLConf.PARQUET_COMPRESSION.key).toUpperCase(Locale.ROOT)) { compressionCodecFor(path, codec.name()) } } @@ -855,7 +855,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession // Checks default compression codec checkCompressionCodec( - ParquetCompressionCodec.fromString(spark.conf.get(SQLConf.PARQUET_COMPRESSION))) + ParquetCompressionCodec.fromString(spark.conf.get(SQLConf.PARQUET_COMPRESSION.key))) ParquetCompressionCodec.availableCodecs.asScala.foreach(checkCompressionCodec(_)) } @@ -1068,7 +1068,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession exception = intercept[SparkException] { spark.read.schema(readSchema).parquet(path).collect() }, - errorClass = "FAILED_READ_FILE.PARQUET_COLUMN_DATA_TYPE_MISMATCH", + condition = "FAILED_READ_FILE.PARQUET_COLUMN_DATA_TYPE_MISMATCH", parameters = Map( "path" -> ".*", "column" -> "\\[_1\\]", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 1484511a98b63..52d67a0954325 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -968,7 +968,7 @@ abstract class ParquetPartitionDiscoverySuite PartitionValues(Seq("b"), Seq(TypedPartValue("1", IntegerType)))) ) ), - errorClass = "CONFLICTING_PARTITION_COLUMN_NAMES", + condition = "CONFLICTING_PARTITION_COLUMN_NAMES", parameters = Map( "distinctPartColLists" -> "\n\tPartition column name list #0: a\n\tPartition column name list #1: b\n", @@ -985,7 +985,7 @@ abstract class ParquetPartitionDiscoverySuite PartitionValues(Seq("a"), Seq(TypedPartValue("1", IntegerType)))) ) ), - errorClass = "CONFLICTING_PARTITION_COLUMN_NAMES", + condition = "CONFLICTING_PARTITION_COLUMN_NAMES", parameters = Map( "distinctPartColLists" -> "\n\tPartition column name list #0: a\n", @@ -1003,7 +1003,7 @@ abstract class ParquetPartitionDiscoverySuite Seq(TypedPartValue("1", IntegerType), TypedPartValue("foo", StringType)))) ) ), - errorClass = "CONFLICTING_PARTITION_COLUMN_NAMES", + condition = "CONFLICTING_PARTITION_COLUMN_NAMES", parameters = Map( "distinctPartColLists" -> "\n\tPartition column name list #0: a\n\tPartition column name list #1: a, b\n", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 074781da830fe..0acb21f3e6fba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -989,7 +989,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { exception = intercept[SparkException] { spark.read.option("mergeSchema", "true").parquet(path) }, - errorClass = "CANNOT_MERGE_SCHEMAS", + condition = "CANNOT_MERGE_SCHEMAS", sqlState = "42KD9", parameters = Map( "left" -> toSQLType(df1.schema), @@ -1056,7 +1056,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { if (col(0).dataType == StringType) { checkErrorMatchPVals( exception = e, - errorClass = "FAILED_READ_FILE.PARQUET_COLUMN_DATA_TYPE_MISMATCH", + condition = "FAILED_READ_FILE.PARQUET_COLUMN_DATA_TYPE_MISMATCH", parameters = Map( "path" -> s".*${dir.getCanonicalPath}.*", "column" -> "\\[a\\]", @@ -1067,7 +1067,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } else { checkErrorMatchPVals( exception = e, - errorClass = "FAILED_READ_FILE.PARQUET_COLUMN_DATA_TYPE_MISMATCH", + condition = "FAILED_READ_FILE.PARQUET_COLUMN_DATA_TYPE_MISMATCH", parameters = Map( "path" -> s".*${dir.getCanonicalPath}.*", "column" -> "\\[a\\]", @@ -1115,7 +1115,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { exception = intercept[AnalysisException] { spark.read.parquet(testDataPath).collect() }, - errorClass = "PARQUET_TYPE_ILLEGAL", + condition = "PARQUET_TYPE_ILLEGAL", parameters = Map("parquetType" -> "INT64 (TIMESTAMP(NANOS,true))") ) } @@ -1126,7 +1126,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { exception = intercept[AnalysisException] { spark.read.parquet(testDataPath).collect() }, - errorClass = "PARQUET_TYPE_NOT_SUPPORTED", + condition = "PARQUET_TYPE_NOT_SUPPORTED", parameters = Map("parquetType" -> "FIXED_LEN_BYTE_ARRAY (INTERVAL)") ) } @@ -1139,7 +1139,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { exception = intercept[AnalysisException] { spark.read.parquet(testDataPath).collect() }, - errorClass = "PARQUET_TYPE_NOT_RECOGNIZED", + condition = 
"PARQUET_TYPE_NOT_RECOGNIZED", parameters = Map("field" -> expectedParameter) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 25aa6def052b8..5c373a2de9738 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -111,7 +111,7 @@ abstract class TextSuite extends QueryTest with SharedSparkSession with CommonFi testDf.write.option("compression", "illegal").mode( SaveMode.Overwrite).text(dir.getAbsolutePath) }, - errorClass = "CODEC_NOT_AVAILABLE.WITH_AVAILABLE_CODECS_SUGGESTION", + condition = "CODEC_NOT_AVAILABLE.WITH_AVAILABLE_CODECS_SUGGESTION", parameters = Map( "codecName" -> "illegal", "availableCodecs" -> "bzip2, deflate, uncompressed, snappy, none, lz4, gzip") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/FileTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/FileTableSuite.scala index 4160516deece5..0316f09e42ce3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/FileTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/FileTableSuite.scala @@ -23,8 +23,15 @@ import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.{QueryTest, SparkSession} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.execution.datasources.v2.csv.CSVScanBuilder +import org.apache.spark.sql.execution.datasources.v2.json.JsonScanBuilder +import org.apache.spark.sql.execution.datasources.v2.orc.OrcScanBuilder +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScanBuilder +import org.apache.spark.sql.execution.datasources.v2.text.TextScanBuilder +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -53,6 +60,8 @@ class DummyFileTable( class FileTableSuite extends QueryTest with SharedSparkSession { + private val allFileBasedDataSources = Seq("orc", "parquet", "csv", "json", "text") + test("Data type validation should check data schema only") { withTempPath { dir => val df = spark.createDataFrame(Seq(("a", 1), ("b", 2))).toDF("v", "p") @@ -85,4 +94,38 @@ class FileTableSuite extends QueryTest with SharedSparkSession { assert(table.dataSchema == expectedDataSchema) } } + + allFileBasedDataSources.foreach { format => + test(s"SPARK-49519: Merge options of table and relation when constructing FileScanBuilder" + + s" - $format") { + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + val userSpecifiedSchema = StructType(Seq(StructField("c1", StringType))) + + DataSource.lookupDataSourceV2(format, spark.sessionState.conf) match { + case Some(provider) => + val dsOptions = new CaseInsensitiveStringMap( + Map("k1" -> "v1", "k2" -> "ds_v2").asJava) + val table = provider.getTable( + userSpecifiedSchema, + Array.empty, + dsOptions.asCaseSensitiveMap()) + val tableOptions = new CaseInsensitiveStringMap( + Map("k2" -> "table_v2", 
"k3" -> "v3").asJava) + val mergedOptions = table.asInstanceOf[FileTable].newScanBuilder(tableOptions) match { + case csv: CSVScanBuilder => csv.options + case json: JsonScanBuilder => json.options + case orc: OrcScanBuilder => orc.options + case parquet: ParquetScanBuilder => parquet.options + case text: TextScanBuilder => text.options + } + assert(mergedOptions.size() == 3) + assert("v1".equals(mergedOptions.get("k1"))) + assert("table_v2".equals(mergedOptions.get("k2"))) + assert("v3".equals(mergedOptions.get("k3"))) + case _ => + throw new IllegalArgumentException(s"Failed to get table provider for $format") + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala index 50988e133005a..c88f51a6b7d06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala @@ -1173,7 +1173,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { exception = intercept[SparkUnsupportedOperationException] { catalog.alterNamespace(testNs, NamespaceChange.removeProperty(p)) }, - errorClass = "_LEGACY_ERROR_TEMP_2069", + condition = "_LEGACY_ERROR_TEMP_2069", parameters = Map("property" -> p)) } @@ -1184,7 +1184,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { val testIdent: IdentifierHelper = Identifier.of(Array("a", "b"), "c") checkError( exception = intercept[AnalysisException](testIdent.asTableIdentifier), - errorClass = "IDENTIFIER_TOO_MANY_NAME_PARTS", + condition = "IDENTIFIER_TOO_MANY_NAME_PARTS", parameters = Map("identifier" -> "`a`.`b`.`c`") ) } @@ -1193,7 +1193,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { val testIdent: MultipartIdentifierHelper = Seq("a", "b", "c") checkError( exception = intercept[AnalysisException](testIdent.asFunctionIdentifier), - errorClass = "IDENTIFIER_TOO_MANY_NAME_PARTS", + condition = "IDENTIFIER_TOO_MANY_NAME_PARTS", parameters = Map("identifier" -> "`a`.`b`.`c`") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/DerbyTableCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/DerbyTableCatalogSuite.scala index d793ef526c47b..6125777c7a426 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/DerbyTableCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/DerbyTableCatalogSuite.scala @@ -45,7 +45,7 @@ class DerbyTableCatalogSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[SparkUnsupportedOperationException]( sql(s"ALTER TABLE $n1t1 RENAME TO $n2t2")), - errorClass = "CANNOT_RENAME_ACROSS_SCHEMA", + condition = "CANNOT_RENAME_ACROSS_SCHEMA", parameters = Map("type" -> "table")) sql(s"ALTER TABLE $n1t1 RENAME TO $n1t2") checkAnswer(sql(s"SHOW TABLES IN derby.test1"), Row("test1", "TABLE2", false)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala index 8e5fba3607b93..580034ff7b0e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala @@ -176,7 +176,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { sql("CREATE TABLE h2.bad_test.new_table(i INT, j STRING)") } checkError(exp, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> "`bad_test`")) } @@ -200,7 +200,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $tableName ADD COLUMNS (c3 DOUBLE)") }, - errorClass = "FIELD_ALREADY_EXISTS", + condition = "FIELD_ALREADY_EXISTS", parameters = Map( "op" -> "add", "fieldNames" -> "`c3`", @@ -239,7 +239,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $tableName RENAME COLUMN C TO C0") }, - errorClass = "FIELD_ALREADY_EXISTS", + condition = "FIELD_ALREADY_EXISTS", parameters = Map( "op" -> "rename", "fieldNames" -> "`C0`", @@ -279,7 +279,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`bad_column`", @@ -316,7 +316,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`bad_column`", @@ -327,7 +327,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[ParseException] { sql(s"ALTER TABLE $tableName ALTER COLUMN id TYPE bad_type") }, - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map("typeName" -> "\"BAD_TYPE\""), context = ExpectedContext("bad_type", 51, 58)) } @@ -361,7 +361,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`bad_column`", @@ -393,7 +393,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("CREATE NAMESPACE h2.test_namespace LOCATION './samplepath'") }, - errorClass = "NOT_SUPPORTED_IN_JDBC_CATALOG.COMMAND", + condition = "NOT_SUPPORTED_IN_JDBC_CATALOG.COMMAND", sqlState = "0A000", parameters = Map("cmd" -> toSQLStmt("CREATE NAMESPACE ... 
LOCATION ..."))) } @@ -416,7 +416,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"ALTER NAMESPACE h2.test_namespace SET LOCATION '/tmp/loc_test_2'") }, - errorClass = "NOT_SUPPORTED_IN_JDBC_CATALOG.COMMAND_WITH_PROPERTY", + condition = "NOT_SUPPORTED_IN_JDBC_CATALOG.COMMAND_WITH_PROPERTY", sqlState = "0A000", parameters = Map( "cmd" -> toSQLStmt("SET NAMESPACE"), @@ -426,7 +426,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"ALTER NAMESPACE h2.test_namespace SET PROPERTIES('a'='b')") }, - errorClass = "NOT_SUPPORTED_IN_JDBC_CATALOG.COMMAND_WITH_PROPERTY", + condition = "NOT_SUPPORTED_IN_JDBC_CATALOG.COMMAND_WITH_PROPERTY", sqlState = "0A000", parameters = Map( "cmd" -> toSQLStmt("SET NAMESPACE"), @@ -444,7 +444,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"ALTER TABLE $tableName ALTER COLUMN ID COMMENT 'test'") }, - errorClass = "_LEGACY_ERROR_TEMP_1305", + condition = "_LEGACY_ERROR_TEMP_1305", parameters = Map("change" -> "org.apache.spark.sql.connector.catalog.TableChange\\$UpdateColumnComment.*"), matchPVals = true) @@ -454,7 +454,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`bad_column`", @@ -490,7 +490,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`C2`", @@ -513,7 +513,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`C3`", @@ -535,7 +535,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`C1`", @@ -557,7 +557,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`C1`", @@ -596,7 +596,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { sql("CREATE TABLE h2.test.new_table(i INT, j STRING)" + " TBLPROPERTIES('ENGINE'='tableEngineName')") }, - errorClass = "FAILED_JDBC.CREATE_TABLE", + condition = "FAILED_JDBC.CREATE_TABLE", parameters = Map( "url" -> "jdbc:.*", "tableName" -> "`test`.`new_table`")) @@ -615,7 +615,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException]{ sql("CREATE TABLE h2.test.new_table(c CHAR(1000000001))") }, - errorClass = "FAILED_JDBC.CREATE_TABLE", + condition = "FAILED_JDBC.CREATE_TABLE", parameters = Map( "url" -> "jdbc:.*", "tableName" -> 
"`test`.`new_table`")) @@ -626,7 +626,7 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkIllegalArgumentException]( sql("CREATE TABLE h2.test.new_table(c array)") ), - errorClass = "_LEGACY_ERROR_TEMP_2082", + condition = "_LEGACY_ERROR_TEMP_2082", parameters = Map("catalogString" -> "array") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala index 2858d356d4c9a..4833b8630134c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala @@ -58,7 +58,7 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB override def beforeAll(): Unit = { super.beforeAll() - spark.conf.set(SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED, false) + spark.conf.set(SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key, false) spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key, newStateStoreProvider().getClass.getName) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index 97c88037a7171..af07707569500 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -942,7 +942,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass // skip version and operator ID to test out functionalities .load() - val numShufflePartitions = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) + val numShufflePartitions = sqlConf.getConf(SQLConf.SHUFFLE_PARTITIONS) val resultDf = stateReadDf .selectExpr("key.value AS key_value", "value.count AS value_count", "partition_id") @@ -966,7 +966,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass } test("partition_id column with stream-stream join") { - val numShufflePartitions = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) + val numShufflePartitions = sqlConf.getConf(SQLConf.SHUFFLE_PARTITIONS) withTempDir { tempDir => runStreamStreamJoinQueryWithOneThousandInputs(tempDir.getAbsolutePath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala index ccd4e005756ad..61091fde35e79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala @@ -21,8 +21,10 @@ import java.time.Duration import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, TestClass} +import org.apache.spark.sql.functions.explode import org.apache.spark.sql.internal.SQLConf -import 
org.apache.spark.sql.streaming.{ExpiredTimerInfo, OutputMode, RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, TimeMode, TimerValues, TransformWithStateSuiteUtils, TTLConfig, ValueState} +import org.apache.spark.sql.streaming.{ExpiredTimerInfo, InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, OutputMode, RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState} +import org.apache.spark.sql.streaming.util.StreamManualClock /** Stateful processor of single value state var with non-primitive type */ class StatefulProcessorWithSingleValueVar extends RunningCountStatefulProcessor { @@ -73,6 +75,52 @@ class StatefulProcessorWithTTL } } +/** Stateful processor tracking groups belonging to sessions with/without TTL */ +class SessionGroupsStatefulProcessor extends + StatefulProcessor[String, (String, String), String] { + @transient private var _groupsList: ListState[String] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _groupsList = getHandle.getListState("groupsList", Encoders.STRING) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[(String, String)], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = { + inputRows.foreach { inputRow => + _groupsList.appendValue(inputRow._2) + } + Iterator.empty + } +} + +class SessionGroupsStatefulProcessorWithTTL extends + StatefulProcessor[String, (String, String), String] { + @transient private var _groupsListWithTTL: ListState[String] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _groupsListWithTTL = getHandle.getListState("groupsListWithTTL", Encoders.STRING, + TTLConfig(Duration.ofMillis(30000))) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[(String, String)], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = { + inputRows.foreach { inputRow => + _groupsListWithTTL.appendValue(inputRow._2) + } + Iterator.empty + } +} + /** * Test suite to verify integration of state data source reader with the transformWithState operator */ @@ -111,7 +159,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest val resultDf = stateReaderDf.selectExpr( "key.value AS groupingKey", - "value.id AS valueId", "value.name AS valueName", + "single_value.id AS valueId", "single_value.name AS valueName", "partition_id") checkAnswer(resultDf, @@ -174,7 +222,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest .load() val resultDf = stateReaderDf.selectExpr( - "key.value", "value.value", "expiration_timestamp", "partition_id") + "key.value", "single_value.value", "single_value.ttlExpirationMs", "partition_id") var count = 0L resultDf.collect().foreach { row => @@ -187,7 +235,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest val answerDf = stateReaderDf.selectExpr( "key.value AS groupingKey", - "value.value AS valueId", "partition_id") + "single_value.value.value AS valueId", "partition_id") checkAnswer(answerDf, Seq(Row("a", 1L, 0), Row("b", 1L, 1))) @@ -217,4 +265,220 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest } } } + + test("state data source integration - list state") { + withTempDir { tempDir => + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + 
classOf[RocksDBStateStoreProvider].getName) { + + val inputData = MemoryStream[(String, String)] + val result = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new SessionGroupsStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, ("session1", "group2")), + AddData(inputData, ("session1", "group1")), + AddData(inputData, ("session2", "group1")), + CheckNewAnswer(), + AddData(inputData, ("session3", "group7")), + AddData(inputData, ("session1", "group4")), + CheckNewAnswer(), + StopStream + ) + + val stateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "groupsList") + .load() + + val listStateDf = stateReaderDf + .selectExpr( + "key.value AS groupingKey", + "list_value.value AS valueList", + "partition_id") + .select($"groupingKey", + explode($"valueList")) + + checkAnswer(listStateDf, + Seq(Row("session1", "group1"), Row("session1", "group2"), Row("session1", "group4"), + Row("session2", "group1"), Row("session3", "group7"))) + } + } + } + + test("state data source integration - list state and TTL") { + withTempDir { tempDir => + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[(String, String)] + val result = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new SessionGroupsStatefulProcessorWithTTL(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, ("session1", "group2")), + AddData(inputData, ("session1", "group1")), + AddData(inputData, ("session2", "group1")), + AddData(inputData, ("session3", "group7")), + AddData(inputData, ("session1", "group4")), + Execute { _ => + // wait for the batch to run since we are using processing time + Thread.sleep(5000) + }, + StopStream + ) + + val stateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "groupsListWithTTL") + .load() + + val listStateDf = stateReaderDf + .selectExpr( + "key.value AS groupingKey", + "list_value AS valueList", + "partition_id") + .select($"groupingKey", + explode($"valueList").as("valueList")) + + val resultDf = listStateDf.selectExpr("valueList.ttlExpirationMs") + var count = 0L + resultDf.collect().foreach { row => + count = count + 1 + assert(row.getLong(0) > 0) + } + + // verify that 5 state rows are present + assert(count === 5) + + val valuesDf = listStateDf.selectExpr("groupingKey", + "valueList.value.value AS groupId") + + checkAnswer(valuesDf, + Seq(Row("session1", "group1"), Row("session1", "group2"), Row("session1", "group4"), + Row("session2", "group1"), Row("session3", "group7"))) + } + } + } + + test("state data source integration - map state with single variable") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { tempDir => + val inputData = MemoryStream[InputMapRow] + val result = inputData.toDS() + .groupByKey(x => x.key) + .transformWithState(new 
TestMapStateProcessor(), + TimeMode.None(), + OutputMode.Append()) + testStream(result, OutputMode.Append())( + StartStream(checkpointLocation = tempDir.getCanonicalPath), + AddData(inputData, InputMapRow("k1", "updateValue", ("v1", "10"))), + AddData(inputData, InputMapRow("k1", "exists", ("", ""))), + AddData(inputData, InputMapRow("k2", "exists", ("", ""))), + CheckNewAnswer(("k1", "exists", "true"), ("k2", "exists", "false")), + + AddData(inputData, InputMapRow("k1", "updateValue", ("v2", "5"))), + AddData(inputData, InputMapRow("k2", "updateValue", ("v2", "3"))), + ProcessAllAvailable(), + StopStream + ) + + val stateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "sessionState") + .load() + + val resultDf = stateReaderDf.selectExpr( + "key.value AS groupingKey", "map_value AS mapValue") + + checkAnswer(resultDf, + Seq( + Row("k1", + Map(Row("v1") -> Row("10"), Row("v2") -> Row("5"))), + Row("k2", + Map(Row("v2") -> Row("3")))) + ) + } + } + } + + test("state data source integration - map state TTL with single variable") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { tempDir => + val inputStream = MemoryStream[MapInputEvent] + val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1)) + val result = inputStream.toDS() + .groupByKey(x => x.key) + .transformWithState( + new MapStateTTLProcessor(ttlConfig), + TimeMode.ProcessingTime(), + OutputMode.Append()) + + val clock = new StreamManualClock + testStream(result)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock, + checkpointLocation = tempDir.getCanonicalPath), + AddData(inputStream, + MapInputEvent("k1", "key1", "put", 1), + MapInputEvent("k1", "key2", "put", 2) + ), + AdvanceManualClock(1 * 1000), // batch timestamp: 1000 + CheckNewAnswer(), + AddData(inputStream, + MapInputEvent("k1", "key1", "get", -1), + MapInputEvent("k1", "key2", "get", -1) + ), + AdvanceManualClock(30 * 1000), // batch timestamp: 31000 + CheckNewAnswer( + MapOutputEvent("k1", "key1", 1, isTTLValue = false, -1), + MapOutputEvent("k1", "key2", 2, isTTLValue = false, -1) + ), + // get values from ttl state + AddData(inputStream, + MapInputEvent("k1", "", "get_values_in_ttl_state", -1) + ), + AdvanceManualClock(1 * 1000), // batch timestamp: 32000 + CheckNewAnswer( + MapOutputEvent("k1", "key1", -1, isTTLValue = true, 61000), + MapOutputEvent("k1", "key2", -1, isTTLValue = true, 61000) + ), + StopStream + ) + + val stateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "mapState") + .load() + + val resultDf = stateReaderDf.selectExpr( + "key.value AS groupingKey", "map_value AS mapValue") + + checkAnswer(resultDf, + Seq( + Row("k1", + Map(Row("key2") -> Row(Row(2), 61000L), + Row("key1") -> Row(Row(1), 61000L)))) + ) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index ebead0b663486..91f21c4a2ed34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -255,7 +255,7 @@ class 
XmlSuite .option("mode", FailFastMode.name) .xml(inputFile) }, - errorClass = "_LEGACY_ERROR_TEMP_2165", + condition = "_LEGACY_ERROR_TEMP_2165", parameters = Map("failFastMode" -> "FAILFAST") ) val exceptionInParsing = intercept[SparkException] { @@ -268,11 +268,11 @@ class XmlSuite } checkErrorMatchPVals( exception = exceptionInParsing, - errorClass = "FAILED_READ_FILE.NO_HINT", + condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*$inputFile.*")) checkError( exception = exceptionInParsing.getCause.asInstanceOf[SparkException], - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map( "badRecord" -> "[null]", "failFastMode" -> FailFastMode.name) @@ -288,7 +288,7 @@ class XmlSuite .option("mode", FailFastMode.name) .xml(inputFile) }, - errorClass = "_LEGACY_ERROR_TEMP_2165", + condition = "_LEGACY_ERROR_TEMP_2165", parameters = Map("failFastMode" -> "FAILFAST")) val exceptionInParsing = intercept[SparkException] { spark.read @@ -300,11 +300,11 @@ class XmlSuite } checkErrorMatchPVals( exception = exceptionInParsing, - errorClass = "FAILED_READ_FILE.NO_HINT", + condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*$inputFile.*")) checkError( exception = exceptionInParsing.getCause.asInstanceOf[SparkException], - errorClass = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", parameters = Map( "badRecord" -> "[null]", "failFastMode" -> FailFastMode.name) @@ -1315,7 +1315,7 @@ class XmlSuite spark.sql(s"""SELECT schema_of_xml('1', map('mode', 'DROPMALFORMED'))""") .collect() }, - errorClass = "_LEGACY_ERROR_TEMP_1099", + condition = "_LEGACY_ERROR_TEMP_1099", parameters = Map( "funcName" -> "schema_of_xml", "mode" -> "DROPMALFORMED", @@ -1330,7 +1330,7 @@ class XmlSuite spark.sql(s"""SELECT schema_of_xml('1', map('mode', 'FAILFAST'))""") .collect() }, - errorClass = "_LEGACY_ERROR_TEMP_2165", + condition = "_LEGACY_ERROR_TEMP_2165", parameters = Map( "failFastMode" -> FailFastMode.name) ) @@ -1394,7 +1394,7 @@ class XmlSuite exception = intercept[AnalysisException] { df.select(to_xml($"value")).collect() }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"to_xml(value)\"", "paramIndex" -> ordinalNumber(0), @@ -1777,14 +1777,14 @@ class XmlSuite exception = intercept[AnalysisException] { spark.read.xml("/this/file/does/not/exist") }, - errorClass = "PATH_NOT_FOUND", + condition = "PATH_NOT_FOUND", parameters = Map("path" -> "file:/this/file/does/not/exist") ) checkError( exception = intercept[AnalysisException] { spark.read.schema(buildSchema(field("dummy"))).xml("/this/file/does/not/exist") }, - errorClass = "PATH_NOT_FOUND", + condition = "PATH_NOT_FOUND", parameters = Map("path" -> "file:/this/file/does/not/exist") ) } @@ -2498,7 +2498,7 @@ class XmlSuite } checkErrorMatchPVals( exception = err, - errorClass = "TASK_WRITE_FAILED", + condition = "TASK_WRITE_FAILED", parameters = Map("path" -> s".*${path.getName}.*")) val msg = err.getCause.getMessage assert( @@ -2923,7 +2923,7 @@ class XmlSuite exception = intercept[SparkException] { df.collect() }, - errorClass = "FAILED_READ_FILE.FILE_NOT_EXIST", + condition = "FAILED_READ_FILE.FILE_NOT_EXIST", parameters = Map("path" -> s".*$dir.*") ) } @@ -3020,7 +3020,7 @@ class XmlSuite } checkErrorMatchPVals( exception = e, - errorClass = "TASK_WRITE_FAILED", + 
condition = "TASK_WRITE_FAILED", parameters = Map("path" -> s".*${dir.getName}.*")) assert(e.getCause.isInstanceOf[XMLStreamException]) assert(e.getCause.getMessage.contains(errorMsg)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 079ab994736b2..e555033b53055 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -548,7 +548,7 @@ class HashedRelationSuite extends SharedSparkSession { exception = intercept[SparkException] { keyIterator.next() }, - errorClass = "_LEGACY_ERROR_TEMP_2104", + condition = "_LEGACY_ERROR_TEMP_2104", parameters = Map.empty ) assert(buffer.sortWith(_ < _) === randomArray) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index b207afeae1068..dcebece29037f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -188,7 +188,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { spark.dataSource.registerPython(dataSourceName, dataSource) checkError( exception = intercept[AnalysisException](spark.read.format(dataSourceName).load()), - errorClass = "INVALID_SCHEMA.NON_STRUCT_TYPE", + condition = "INVALID_SCHEMA.NON_STRUCT_TYPE", parameters = Map("inputSchema" -> "INT", "dataType" -> "\"INT\"")) } @@ -309,7 +309,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { exception = intercept[AnalysisException] { spark.dataSource.registerPython(provider, dataSource) }, - errorClass = "DATA_SOURCE_ALREADY_EXISTS", + condition = "DATA_SOURCE_ALREADY_EXISTS", parameters = Map("provider" -> provider)) } } @@ -657,7 +657,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { exception = intercept[AnalysisException] { spark.range(1).write.format(dataSourceName).save() }, - errorClass = "UNSUPPORTED_DATA_SOURCE_SAVE_MODE", + condition = "UNSUPPORTED_DATA_SOURCE_SAVE_MODE", parameters = Map("source" -> "SimpleDataSource", "createMode" -> "\"ErrorIfExists\"")) } @@ -666,7 +666,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { exception = intercept[AnalysisException] { spark.range(1).write.format(dataSourceName).mode("ignore").save() }, - errorClass = "UNSUPPORTED_DATA_SOURCE_SAVE_MODE", + condition = "UNSUPPORTED_DATA_SOURCE_SAVE_MODE", parameters = Map("source" -> "SimpleDataSource", "createMode" -> "\"Ignore\"")) } @@ -675,7 +675,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { exception = intercept[AnalysisException] { spark.range(1).write.format(dataSourceName).mode("foo").save() }, - errorClass = "INVALID_SAVE_MODE", + condition = "INVALID_SAVE_MODE", parameters = Map("mode" -> "\"foo\"")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala index 8d62d4747198b..8d0e1c5f578fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala @@ -258,7 +258,7 @@ class PythonStreamingDataSourceSimpleSuite extends PythonDataSourceSuiteBase { } checkErrorMatchPVals( err, - errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR", + condition = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR", parameters = Map( "action" -> action, "msg" -> "(.|\\n)*" @@ -324,7 +324,7 @@ class PythonStreamingDataSourceSimpleSuite extends PythonDataSourceSuiteBase { } checkErrorMatchPVals( err, - errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR", + condition = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR", parameters = Map( "action" -> action, "msg" -> "(.|\\n)*" @@ -661,7 +661,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase { } checkErrorMatchPVals( err, - errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR", + condition = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR", parameters = Map( "action" -> action, "msg" -> "(.|\\n)*" @@ -723,7 +723,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase { } checkErrorMatchPVals( err, - errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR", + condition = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR", parameters = Map( "action" -> action, "msg" -> "(.|\\n)*" @@ -1035,7 +1035,7 @@ class PythonStreamingDataSourceWriteSuite extends PythonDataSourceSuiteBase { exception = intercept[AnalysisException] { runQuery("complete") }, - errorClass = "STREAMING_OUTPUT_MODE.UNSUPPORTED_OPERATION", + condition = "STREAMING_OUTPUT_MODE.UNSUPPORTED_OPERATION", sqlState = "42KDE", parameters = Map( "outputMode" -> "complete", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index f82b544ecf120..0339f7461f0a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -96,7 +96,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.agg(testUdf(df("v"))).collect() }, - errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", parameters = Map("sqlExpr" -> "\"pandas_udf(v)\"", "dataType" -> "VARIANT")) } @@ -110,7 +110,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.agg(testUdf(df("arr_v"))).collect() }, - errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", parameters = Map("sqlExpr" -> "\"pandas_udf(arr_v)\"", "dataType" -> "ARRAY<VARIANT>")) } @@ -123,7 +123,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.agg(testUdf(df("id"))).collect() }, - errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", parameters = Map("sqlExpr" -> "\"pandas_udf(id)\"", "dataType" -> "VARIANT")) } @@ -136,7 +136,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.agg(testUdf(df("id"))).collect() }, - errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", parameters = Map( "sqlExpr" -> "\"pandas_udf(id)\"", "dataType" -> 
"STRUCT>")) @@ -175,7 +175,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { spark.range(1).select(transform(array("id"), x => pythonTestUDF(x))).collect() }, - errorClass = "UNSUPPORTED_FEATURE.LAMBDA_FUNCTION_WITH_PYTHON_UDF", + condition = "UNSUPPORTED_FEATURE.LAMBDA_FUNCTION_WITH_PYTHON_UDF", parameters = Map("funcName" -> "\"pyUDF(namedlambdavariable())\""), context = ExpectedContext( "transform", s".*${this.getClass.getSimpleName}.*")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala index 36a9c3c40e3e6..041bd143067a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala @@ -133,7 +133,7 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { spark.sql("select udtf.* from t, lateral variantInputUDTF(v) udtf").collect() }, - errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", parameters = Map( "sqlExpr" -> """"InputVariantUDTF\(outer\(v#\d+\)\)"""", "dataType" -> "VARIANT"), @@ -156,7 +156,7 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { spark.sql("select udtf.* from t, lateral variantInputUDTF(map_v) udtf").collect() }, - errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", parameters = Map( "sqlExpr" -> """"InputVariantUDTF\(outer\(map_v#\d+\)\)"""", "dataType" -> "MAP"), @@ -175,7 +175,7 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { spark.sql("select * from variantOutUDTF()").collect() }, - errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", parameters = Map( "sqlExpr" -> "\"SimpleOutputVariantUDTF()\"", "dataType" -> "VARIANT"), @@ -192,7 +192,7 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { spark.sql("select * from arrayOfVariantOutUDTF()").collect() }, - errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", parameters = Map( "sqlExpr" -> "\"OutputArrayOfVariantUDTF()\"", "dataType" -> "ARRAY"), @@ -488,7 +488,7 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { | WITH SINGLE PARTITION | ORDER BY device_id, data_ds) |""".stripMargin)), - errorClass = "_LEGACY_ERROR_TEMP_0064", + condition = "_LEGACY_ERROR_TEMP_0064", parameters = Map("msg" -> ("The table function call includes a table argument with an invalid " + "partitioning/ordering specification: the ORDER BY clause included multiple " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala index 9e494fdddda9c..615e1e89f30b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala @@ -31,8 +31,9 @@ 
import org.apache.spark.sql.{Encoder, Row} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.execution.streaming.{StatefulProcessorHandleImpl, StatefulProcessorHandleState} +import org.apache.spark.sql.execution.streaming.state.StateMessage import org.apache.spark.sql.execution.streaming.state.StateMessage.{Clear, Exists, Get, HandleState, SetHandleState, StateCallCommand, StatefulProcessorCall, ValueStateCall, ValueStateUpdate} -import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.streaming.{TTLConfig, ValueState} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with BeforeAndAfterEach { @@ -67,14 +68,27 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo verify(outputStream).writeInt(0) } - test("get value state") { - val message = StatefulProcessorCall.newBuilder().setGetValueState( - StateCallCommand.newBuilder() + Seq(true, false).foreach { useTTL => + test(s"get value state, useTTL=$useTTL") { + val stateCallCommandBuilder = StateCallCommand.newBuilder() .setStateName("newName") - .setSchema("StructType(List(StructField(value,IntegerType,true)))")).build() - stateServer.handleStatefulProcessorCall(message) - verify(statefulProcessorHandle).getValueState[Row](any[String], any[Encoder[Row]]) - verify(outputStream).writeInt(0) + .setSchema("StructType(List(StructField(value,IntegerType,true)))") + if (useTTL) { + stateCallCommandBuilder.setTtl(StateMessage.TTLConfig.newBuilder().setDurationMs(1000)) + } + val message = StatefulProcessorCall + .newBuilder() + .setGetValueState(stateCallCommandBuilder.build()) + .build() + stateServer.handleStatefulProcessorCall(message) + if (useTTL) { + verify(statefulProcessorHandle) + .getValueState[Row](any[String], any[Encoder[Row]], any[TTLConfig]) + } else { + verify(statefulProcessorHandle).getValueState[Row](any[String], any[Encoder[Row]]) + } + verify(outputStream).writeInt(0) + } } test("value state exists") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 42eb9fa17a210..808ffe036f89d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -247,7 +247,7 @@ class CompactibleFileStreamLogSuite extends SharedSparkSession { exception = intercept[SparkUnsupportedOperationException] { compactibleLog.purge(2) }, - errorClass = "_LEGACY_ERROR_TEMP_2260", + condition = "_LEGACY_ERROR_TEMP_2260", parameters = Map.empty) // Below line would fail with IllegalStateException if we don't prevent purge: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSessionsIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSessionsIteratorSuite.scala index e550d8ef46085..8ed63f5680b0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSessionsIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSessionsIteratorSuite.scala @@ -197,7 +197,7 @@ class MergingSessionsIteratorSuite extends SharedSparkSession { exception = 
intercept[SparkException] { iterator.next() }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "Input iterator is not sorted based on session!")) // afterwards, calling either hasNext() or next() will throw IllegalStateException @@ -205,14 +205,14 @@ class MergingSessionsIteratorSuite extends SharedSparkSession { exception = intercept[SparkException] { iterator.hasNext }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "The iterator is already corrupted.")) checkError( exception = intercept[SparkException] { iterator.next() }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "The iterator is already corrupted.")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala index 88af5cfddb487..187eda5d36f61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala @@ -270,7 +270,7 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession { exception = intercept[SparkException] { iterator.next() }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "The iterator must be sorted by key and session start!")) // afterwards, calling either hasNext() or next() will throw IllegalStateException @@ -278,14 +278,14 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession { exception = intercept[SparkException] { iterator.hasNext }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "The iterator is already corrupted.")) checkError( exception = intercept[SparkException] { iterator.next() }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "The iterator is already corrupted.")) } @@ -312,7 +312,7 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession { exception = intercept[SparkException] { iterator.next() }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "The iterator must be sorted by key and session start!")) // afterwards, calling either hasNext() or next() will throw IllegalStateException @@ -320,14 +320,14 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession { exception = intercept[SparkException] { iterator.hasNext }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "The iterator is already corrupted.")) checkError( exception = intercept[SparkException] { iterator.next() }, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map("message" -> "The iterator is already corrupted.")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index 324717d92c972..32f92ce276a06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import org.scalatest.BeforeAndAfter -import 
org.apache.spark.SparkException +import org.apache.spark.{ExecutorDeadException, SparkException} import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.{count, timestamp_seconds, window} @@ -128,12 +128,14 @@ class ForeachWriterSuite extends StreamTest with SharedSparkSession with BeforeA testQuietly("foreach with error") { withTempDir { checkpointDir => val input = MemoryStream[Int] + + val funcEx = new RuntimeException("ForeachSinkSuite error") val query = input.toDS().repartition(1).writeStream .option("checkpointLocation", checkpointDir.getCanonicalPath) .foreach(new TestForeachWriter() { override def process(value: Int): Unit = { super.process(value) - throw new RuntimeException("ForeachSinkSuite error") + throw funcEx } }).start() input.addData(1, 2, 3, 4) @@ -142,8 +144,13 @@ class ForeachWriterSuite extends StreamTest with SharedSparkSession with BeforeA val e = intercept[StreamingQueryException] { query.processAllAvailable() } - assert(e.getCause.isInstanceOf[SparkException]) - assert(e.getCause.getCause.getMessage === "ForeachSinkSuite error") + + val errClass = "FOREACH_USER_FUNCTION_ERROR" + + // verify that we classified the exception + assert(e.getMessage.contains(errClass)) + assert(e.cause.asInstanceOf[RuntimeException].getMessage == funcEx.getMessage) + assert(query.isActive === false) val allEvents = ForeachWriterSuite.allEvents() @@ -157,6 +164,23 @@ class ForeachWriterSuite extends StreamTest with SharedSparkSession with BeforeA assert(errorEvent.error.get.getMessage === "ForeachSinkSuite error") // 'close' shouldn't be called with abort message if close with error has been called assert(allEvents(0).size == 3) + + val sparkEx = ExecutorDeadException("network error") + val e2 = intercept[StreamingQueryException] { + val query2 = input.toDS().repartition(1).writeStream + .foreach(new TestForeachWriter() { + override def process(value: Int): Unit = { + super.process(value) + throw sparkEx + } + }).start() + query2.processAllAvailable() + } + + // we didn't wrap the spark exception + assert(!e2.getMessage.contains(errClass)) + assert(e2.getCause.getCause.asInstanceOf[ExecutorDeadException].getMessage + == sparkEx.getMessage) } } @@ -286,7 +310,7 @@ class ForeachWriterSuite extends StreamTest with SharedSparkSession with BeforeA val errorEvent = allEvents(0)(1).asInstanceOf[ForeachWriterSuite.Close] checkError( exception = errorEvent.error.get.asInstanceOf[SparkException], - errorClass = "_LEGACY_ERROR_TEMP_2256", + condition = "_LEGACY_ERROR_TEMP_2256", parameters = Map.empty ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchProviderSuite.scala index 128b59b26b823..aa03545625ec9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchProviderSuite.scala @@ -202,7 +202,7 @@ class RatePerMicroBatchProviderSuite extends StreamTest { .schema(spark.range(1).schema) .load() }, - errorClass = "_LEGACY_ERROR_TEMP_2242", + condition = "_LEGACY_ERROR_TEMP_2242", parameters = Map("provider" -> "RatePerMicroBatchProvider")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 0732e126a0131..aeb1bba31410d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -208,7 +208,7 @@ class RateStreamProviderSuite extends StreamTest { checkError( exception = e, - errorClass = "INCORRECT_RAMP_UP_RATE", + condition = "INCORRECT_RAMP_UP_RATE", parameters = Map( "rowsPerSecond" -> Long.MaxValue.toString, "maxSeconds" -> "1", @@ -229,7 +229,7 @@ class RateStreamProviderSuite extends StreamTest { checkError( exception = e, - errorClass = "INTERNAL_ERROR", + condition = "INTERNAL_ERROR", parameters = Map( ("message" -> ("Max offset with 100 rowsPerSecond is 92233720368547758, " + @@ -352,7 +352,7 @@ class RateStreamProviderSuite extends StreamTest { .schema(spark.range(1).schema) .load() }, - errorClass = "_LEGACY_ERROR_TEMP_2242", + condition = "_LEGACY_ERROR_TEMP_2242", parameters = Map("provider" -> "RateStreamProvider")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 87e34601dc098..2c17d75624d38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -198,7 +198,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSparkSession { exception = intercept[SparkUnsupportedOperationException] { spark.readStream.schema(userSpecifiedSchema).format("socket").options(params).load() }, - errorClass = "_LEGACY_ERROR_TEMP_2242", + condition = "_LEGACY_ERROR_TEMP_2242", parameters = Map("provider" -> "TextSocketSourceProvider")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala index ea6fd8ab312c9..2456999b4382a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.GroupStateImpl._ import org.apache.spark.sql.streaming.StreamTest @@ -201,7 +201,7 @@ class FlatMapGroupsWithStateExecHelperSuite extends StreamTest { private def newStateManager[T: Encoder](version: Int, withTimestamp: Boolean): StateManager = { FlatMapGroupsWithStateExecHelper.createStateManager( - implicitly[Encoder[T]].asInstanceOf[ExpressionEncoder[Any]], + encoderFor[T].asInstanceOf[ExpressionEncoder[Any]], withTimestamp, version) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala index 48816486cbd00..e9300464af8dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala @@ -22,7 +22,7 @@ import java.util.UUID import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException} import org.apache.spark.sql.Encoders -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, ListStateImplWithTTL, StatefulProcessorHandleImpl} import org.apache.spark.sql.streaming.{ListState, TimeMode, TTLConfig, ValueState} @@ -38,7 +38,7 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + stringEncoder, TimeMode.None()) val listState: ListState[Long] = handle.getListState[Long]("listState", Encoders.scalaLong) @@ -49,7 +49,7 @@ class ListStateSuite extends StateVariableSuiteBase { checkError( exception = e, - errorClass = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE", + condition = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE", sqlState = Some("42601"), parameters = Map("stateName" -> "listState") ) @@ -71,7 +71,7 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + stringEncoder, TimeMode.None()) val testState: ListState[Long] = handle.getListState[Long]("testState", Encoders.scalaLong) ImplicitGroupingKeyTracker.setImplicitKey("test_key") @@ -99,7 +99,7 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + stringEncoder, TimeMode.None()) val testState1: ListState[Long] = handle.getListState[Long]("testState1", Encoders.scalaLong) val testState2: ListState[Long] = handle.getListState[Long]("testState2", Encoders.scalaLong) @@ -137,7 +137,7 @@ class ListStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + stringEncoder, TimeMode.None()) val listState1: ListState[Long] = handle.getListState[Long]("listState1", Encoders.scalaLong) val listState2: ListState[Long] = handle.getListState[Long]("listState2", Encoders.scalaLong) @@ -167,7 +167,7 @@ class ListStateSuite extends StateVariableSuiteBase { val store = provider.getStore(0) val timestampMs = 10 val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + stringEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs)) val ttlConfig = 
TTLConfig(ttlDuration = Duration.ofMinutes(1)) @@ -187,7 +187,7 @@ class ListStateSuite extends StateVariableSuiteBase { // increment batchProcessingTime, or watermark and ensure expired value is not returned val nextBatchHandle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + stringEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(ttlExpirationMs)) val nextBatchTestState: ListStateImplWithTTL[String] = @@ -223,7 +223,7 @@ class ListStateSuite extends StateVariableSuiteBase { val store = provider.getStore(0) val batchTimestampMs = 10 val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + stringEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs)) Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration => @@ -234,7 +234,7 @@ class ListStateSuite extends StateVariableSuiteBase { checkError( ex, - errorClass = "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE", + condition = "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE", parameters = Map( "operationType" -> "update", "stateName" -> "testState" @@ -250,7 +250,7 @@ class ListStateSuite extends StateVariableSuiteBase { val store = provider.getStore(0) val timestampMs = 10 val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.bean(classOf[POJOTestClass]).asInstanceOf[ExpressionEncoder[Any]], + encoderFor(Encoders.bean(classOf[POJOTestClass])).asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs)) val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala index b5ba25518a5ea..b067d589de904 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala @@ -22,7 +22,6 @@ import java.util.UUID import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.Encoders -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, MapStateImplWithTTL, StatefulProcessorHandleImpl} import org.apache.spark.sql.streaming.{ListState, MapState, TimeMode, TTLConfig, ValueState} import org.apache.spark.sql.types.{BinaryType, StructType} @@ -41,7 +40,7 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + stringEncoder, TimeMode.None()) val testState: MapState[String, Double] = handle.getMapState[String, Double]("testState", Encoders.STRING, Encoders.scalaDouble) @@ -75,7 +74,7 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + stringEncoder, TimeMode.None()) val testState1: MapState[Long, Double] = handle.getMapState[Long, Double]("testState1", 
Encoders.scalaLong, Encoders.scalaDouble) @@ -114,7 +113,7 @@ class MapStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + stringEncoder, TimeMode.None()) val mapTestState1: MapState[String, Int] = handle.getMapState[String, Int]("mapTestState1", Encoders.STRING, Encoders.scalaInt) @@ -175,7 +174,7 @@ class MapStateSuite extends StateVariableSuiteBase { val store = provider.getStore(0) val timestampMs = 10 val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(), + stringEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs)) val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1)) @@ -196,7 +195,7 @@ class MapStateSuite extends StateVariableSuiteBase { // increment batchProcessingTime, or watermark and ensure expired value is not returned val nextBatchHandle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + stringEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(ttlExpirationMs)) val nextBatchTestState: MapStateImplWithTTL[String, String] = @@ -233,7 +232,7 @@ class MapStateSuite extends StateVariableSuiteBase { val store = provider.getStore(0) val batchTimestampMs = 10 val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + stringEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs)) Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration => @@ -245,7 +244,7 @@ class MapStateSuite extends StateVariableSuiteBase { checkError( ex, - errorClass = "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE", + condition = "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE", parameters = Map( "operationType" -> "update", "stateName" -> "testState" @@ -261,7 +260,7 @@ class MapStateSuite extends StateVariableSuiteBase { val store = provider.getStore(0) val timestampMs = 10 val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(), + stringEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs)) val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala index 8fcd6edf1abb7..d20cfb04f8e81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala @@ -119,7 +119,7 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest private def getFormatVersion(query: StreamingQuery): Int = { query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.sparkSession - .conf.get(SQLConf.STATE_STORE_ROCKSDB_FORMAT_VERSION) + .sessionState.conf.getConf(SQLConf.STATE_STORE_ROCKSDB_FORMAT_VERSION) } testWithColumnFamilies("SPARK-36519: store RocksDB format version in the checkpoint", diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 0bd86068ca3f6..32467a2dd11bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -174,7 +174,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } checkError( ex1, - errorClass = "STATE_STORE_INCORRECT_NUM_ORDERING_COLS_FOR_RANGE_SCAN", + condition = "STATE_STORE_INCORRECT_NUM_ORDERING_COLS_FOR_RANGE_SCAN", parameters = Map( "numOrderingCols" -> "0" ), @@ -193,7 +193,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } checkError( ex2, - errorClass = "STATE_STORE_INCORRECT_NUM_ORDERING_COLS_FOR_RANGE_SCAN", + condition = "STATE_STORE_INCORRECT_NUM_ORDERING_COLS_FOR_RANGE_SCAN", parameters = Map( "numOrderingCols" -> (keySchemaWithRangeScan.length + 1).toString ), @@ -215,7 +215,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } checkError( ex, - errorClass = "STATE_STORE_VARIABLE_SIZE_ORDERING_COLS_NOT_SUPPORTED", + condition = "STATE_STORE_VARIABLE_SIZE_ORDERING_COLS_NOT_SUPPORTED", parameters = Map( "fieldName" -> keySchemaWithVariableSizeCols.fields(0).name, "index" -> "0" @@ -253,7 +253,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } checkError( ex, - errorClass = "STATE_STORE_VARIABLE_SIZE_ORDERING_COLS_NOT_SUPPORTED", + condition = "STATE_STORE_VARIABLE_SIZE_ORDERING_COLS_NOT_SUPPORTED", parameters = Map( "fieldName" -> field.name, "index" -> index.toString @@ -278,7 +278,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } checkError( ex, - errorClass = "STATE_STORE_NULL_TYPE_ORDERING_COLS_NOT_SUPPORTED", + condition = "STATE_STORE_NULL_TYPE_ORDERING_COLS_NOT_SUPPORTED", parameters = Map( "fieldName" -> keySchemaWithNullTypeCols.fields(0).name, "index" -> "0" @@ -934,7 +934,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid if (!colFamiliesEnabled) { checkError( ex, - errorClass = "STATE_STORE_UNSUPPORTED_OPERATION", + condition = "STATE_STORE_UNSUPPORTED_OPERATION", parameters = Map( "operationType" -> "create_col_family", "entity" -> "multiple column families is disabled in RocksDBStateStoreProvider" @@ -944,7 +944,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } else { checkError( ex, - errorClass = "STATE_STORE_CANNOT_USE_COLUMN_FAMILY_WITH_INVALID_NAME", + condition = "STATE_STORE_CANNOT_USE_COLUMN_FAMILY_WITH_INVALID_NAME", parameters = Map( "operationName" -> "create_col_family", "colFamilyName" -> colFamilyName @@ -962,7 +962,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid newStoreProvider(useColumnFamilies = colFamiliesEnabled)) { provider => val store = provider.getStore(0) - Seq("_internal", "_test", "_test123", "__12345").foreach { colFamilyName => + Seq("$internal", "$test", "$test123", "$_12345", "$$$235").foreach { colFamilyName => val ex = intercept[SparkUnsupportedOperationException] { store.createColFamilyIfAbsent(colFamilyName, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema)) @@ -971,7 +971,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid if 
(!colFamiliesEnabled) { checkError( ex, - errorClass = "STATE_STORE_UNSUPPORTED_OPERATION", + condition = "STATE_STORE_UNSUPPORTED_OPERATION", parameters = Map( "operationType" -> "create_col_family", "entity" -> "multiple column families is disabled in RocksDBStateStoreProvider" @@ -981,11 +981,11 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } else { checkError( ex, - errorClass = "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS", + condition = "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS", parameters = Map( "colFamilyName" -> colFamilyName ), - matchPVals = true + matchPVals = false ) } } @@ -1073,7 +1073,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } checkError( exception = e.asInstanceOf[StateStoreUnsupportedOperationOnMissingColumnFamily], - errorClass = "STATE_STORE_UNSUPPORTED_OPERATION_ON_MISSING_COLUMN_FAMILY", + condition = "STATE_STORE_UNSUPPORTED_OPERATION_ON_MISSING_COLUMN_FAMILY", sqlState = Some("42802"), parameters = Map("operationType" -> "get", "colFamilyName" -> colFamily1) ) @@ -1221,7 +1221,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid checkError( exception = e.asInstanceOf[StateStoreUnsupportedOperationOnMissingColumnFamily], - errorClass = "STATE_STORE_UNSUPPORTED_OPERATION_ON_MISSING_COLUMN_FAMILY", + condition = "STATE_STORE_UNSUPPORTED_OPERATION_ON_MISSING_COLUMN_FAMILY", sqlState = Some("42802"), parameters = Map("operationType" -> "iterator", "colFamilyName" -> cfName) ) @@ -1241,7 +1241,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid if (!colFamiliesEnabled) { checkError( ex, - errorClass = "STATE_STORE_UNSUPPORTED_OPERATION", + condition = "STATE_STORE_UNSUPPORTED_OPERATION", parameters = Map( "operationType" -> operationName, "entity" -> "multiple column families is disabled in RocksDBStateStoreProvider" @@ -1251,7 +1251,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } else { checkError( ex, - errorClass = "STATE_STORE_UNSUPPORTED_OPERATION_ON_MISSING_COLUMN_FAMILY", + condition = "STATE_STORE_UNSUPPORTED_OPERATION_ON_MISSING_COLUMN_FAMILY", parameters = Map( "operationType" -> operationName, "colFamilyName" -> colFamilyName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index d07ce07c41e5c..608a22a284b6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -34,7 +34,7 @@ import org.rocksdb.CompressionType import org.scalactic.source.Position import org.scalatest.Tag -import org.apache.spark.SparkException +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.execution.streaming.{CreateAtomicTestManager, FileSystemBasedCheckpointFileManager} import org.apache.spark.sql.execution.streaming.CheckpointFileManager.{CancellableFSDataOutputStream, RenameBasedFSDataOutputStream} @@ -167,7 +167,10 @@ trait AlsoTestWithChangelogCheckpointingEnabled @SlowSQLTest class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with SharedSparkSession { - sqlConf.setConf(SQLConf.STATE_STORE_PROVIDER_CLASS, classOf[RocksDBStateStoreProvider].getName) + override protected def 
sparkConf: SparkConf = { + super.sparkConf + .set(SQLConf.STATE_STORE_PROVIDER_CLASS, classOf[RocksDBStateStoreProvider].getName) + } testWithColumnFamilies( "RocksDB: check changelog and snapshot version", @@ -202,7 +205,7 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } checkError( ex, - errorClass = "CANNOT_LOAD_STATE_STORE.UNEXPECTED_VERSION", + condition = "CANNOT_LOAD_STATE_STORE.UNEXPECTED_VERSION", parameters = Map("version" -> "-1") ) ex = intercept[SparkException] { @@ -210,7 +213,7 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } checkError( ex, - errorClass = "CANNOT_LOAD_STATE_STORE.UNEXPECTED_VERSION", + condition = "CANNOT_LOAD_STATE_STORE.UNEXPECTED_VERSION", parameters = Map("version" -> "-1") ) @@ -222,7 +225,7 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } checkError( ex, - errorClass = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_STREAMING_STATE_FILE", + condition = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_STREAMING_STATE_FILE", parameters = Map( "fileToRead" -> s"$remoteDir/1.changelog" ) @@ -808,6 +811,47 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } } + testWithChangelogCheckpointingEnabled("RocksDB: ensure that changelog files are written " + + "and snapshots uploaded optionally with changelog format v2") { + withTempDir { dir => + val remoteDir = Utils.createTempDir().toString + val conf = dbConf.copy(minDeltasForSnapshot = 5, compactOnCommit = false) + new File(remoteDir).delete() // to make sure that the directory gets created + withDB(remoteDir, conf = conf, useColumnFamilies = true) { db => + db.createColFamilyIfAbsent("test") + db.load(0) + db.put("a", "1") + db.put("b", "2") + db.commit() + assert(changelogVersionsPresent(remoteDir) == Seq(1)) + assert(snapshotVersionsPresent(remoteDir) == Seq(1)) + + db.load(1) + db.put("a", "3") + db.put("c", "4") + db.commit() + + assert(changelogVersionsPresent(remoteDir) == Seq(1, 2)) + assert(snapshotVersionsPresent(remoteDir) == Seq(1)) + + db.removeColFamilyIfExists("test") + db.load(2) + db.remove("a") + db.put("d", "5") + db.commit() + assert(changelogVersionsPresent(remoteDir) == Seq(1, 2, 3)) + assert(snapshotVersionsPresent(remoteDir) == Seq(1, 3)) + + db.load(3) + db.put("e", "6") + db.remove("b") + db.commit() + assert(changelogVersionsPresent(remoteDir) == Seq(1, 2, 3, 4)) + assert(snapshotVersionsPresent(remoteDir) == Seq(1, 3)) + } + } + } + test("RocksDB: ensure merge operation correctness") { withTempDir { dir => val remoteDir = Utils.createTempDir().toString @@ -1107,7 +1151,7 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } checkError( ex, - errorClass = "CANNOT_LOAD_STATE_STORE.UNRELEASED_THREAD_ERROR", + condition = "CANNOT_LOAD_STATE_STORE.UNRELEASED_THREAD_ERROR", parameters = Map( "loggingId" -> "\\[Thread-\\d+\\]", "operationType" -> "load_store", @@ -1135,7 +1179,7 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } checkError( ex, - errorClass = "CANNOT_LOAD_STATE_STORE.UNRELEASED_THREAD_ERROR", + condition = "CANNOT_LOAD_STATE_STORE.UNRELEASED_THREAD_ERROR", parameters = Map( "loggingId" -> "\\[Thread-\\d+\\]", "operationType" -> "load_store", @@ -1187,7 +1231,7 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } checkError( e, - errorClass = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_CHECKPOINT", + condition = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_CHECKPOINT", 
parameters = Map( "expectedVersion" -> "v2", "actualVersion" -> "v1" @@ -2157,9 +2201,7 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } } - private def sqlConf = SQLConf.get.clone() - - private def dbConf = RocksDBConf(StateStoreConf(sqlConf)) + private def dbConf = RocksDBConf(StateStoreConf(SQLConf.get.clone())) def withDB[T]( remoteDir: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index b41fb91dd5d01..8bbc7a31760d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -175,7 +175,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } checkError( ex, - errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_MULTIPLE_COLUMN_FAMILIES", + condition = "UNSUPPORTED_FEATURE.STATE_STORE_MULTIPLE_COLUMN_FAMILIES", parameters = Map( "stateStoreProvider" -> "HDFSBackedStateStoreProvider" ), @@ -187,7 +187,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } checkError( ex, - errorClass = "STATE_STORE_UNSUPPORTED_OPERATION", + condition = "STATE_STORE_UNSUPPORTED_OPERATION", parameters = Map( "operationType" -> operationName, "entity" -> "HDFSBackedStateStoreProvider" @@ -241,7 +241,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } checkError( ex, - errorClass = "STATE_STORE_UNSUPPORTED_OPERATION", + condition = "STATE_STORE_UNSUPPORTED_OPERATION", parameters = Map( "operationType" -> "Range scan", "entity" -> "HDFSBackedStateStoreProvider" @@ -373,7 +373,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } checkError( e, - errorClass = "CANNOT_LOAD_STATE_STORE.UNCATEGORIZED", + condition = "CANNOT_LOAD_STATE_STORE.UNCATEGORIZED", parameters = Map.empty ) @@ -385,7 +385,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } checkError( e, - errorClass = "CANNOT_LOAD_STATE_STORE.UNCATEGORIZED", + condition = "CANNOT_LOAD_STATE_STORE.UNCATEGORIZED", parameters = Map.empty ) @@ -396,7 +396,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } checkError( e, - errorClass = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_DELTA_FILE_NOT_EXISTS", + condition = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_DELTA_FILE_NOT_EXISTS", parameters = Map( "fileToRead" -> s"${provider.stateStoreId.storeCheckpointLocation()}/1.delta", "clazz" -> s"${provider.toString()}" @@ -1273,21 +1273,21 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] if (version < 0) { checkError( e, - errorClass = "CANNOT_LOAD_STATE_STORE.UNEXPECTED_VERSION", + condition = "CANNOT_LOAD_STATE_STORE.UNEXPECTED_VERSION", parameters = Map("version" -> version.toString) ) } else { if (isHDFSBackedStoreProvider) { checkError( e, - errorClass = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_DELTA_FILE_NOT_EXISTS", + condition = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_DELTA_FILE_NOT_EXISTS", parameters = Map("fileToRead" -> ".*", "clazz" -> ".*"), matchPVals = true ) } else { checkError( e, - errorClass = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_STREAMING_STATE_FILE", + condition = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_STREAMING_STATE_FILE", parameters = Map("fileToRead" -> ".*"), matchPVals = true ) @@ -1478,7 +1478,7 @@ 
abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] } checkError( e, - errorClass = "CANNOT_LOAD_STATE_STORE.UNEXPECTED_VERSION", + condition = "CANNOT_LOAD_STATE_STORE.UNEXPECTED_VERSION", parameters = Map( "version" -> "-1" ) @@ -1493,7 +1493,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] } checkError( e, - errorClass = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_DELTA_FILE_NOT_EXISTS", + condition = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_DELTA_FILE_NOT_EXISTS", parameters = Map( "fileToRead" -> s"$dir/0/0/1.delta", "clazz" -> "HDFSStateStoreProvider\\[.+\\]" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala index 52bdb0213c7e5..48a6fd836a462 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala @@ -22,7 +22,6 @@ import java.util.UUID import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.Encoders -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} import org.apache.spark.sql.streaming.{TimeMode, TTLConfig} @@ -33,9 +32,6 @@ import org.apache.spark.sql.streaming.{TimeMode, TTLConfig} */ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { - private def keyExprEncoder: ExpressionEncoder[Any] = - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]] - private def getTimeMode(timeMode: String): TimeMode = { timeMode match { case "None" => TimeMode.None() @@ -50,7 +46,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode)) + UUID.randomUUID(), stringEncoder, getTimeMode(timeMode)) assert(handle.getHandleState === StatefulProcessorHandleState.CREATED) handle.getValueState[Long]("testState", Encoders.scalaLong) } @@ -68,7 +64,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { } checkError( ex, - errorClass = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE", + condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE", parameters = Map( "operationType" -> operationType, "handleState" -> handleState.toString @@ -91,7 +87,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode)) + UUID.randomUUID(), stringEncoder, getTimeMode(timeMode)) Seq(StatefulProcessorHandleState.INITIALIZED, StatefulProcessorHandleState.DATA_PROCESSED, @@ -109,14 +105,14 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, TimeMode.None()) + UUID.randomUUID(), stringEncoder, 
TimeMode.None()) val ex = intercept[SparkUnsupportedOperationException] { handle.registerTimer(10000L) } checkError( ex, - errorClass = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_TIME_MODE", + condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_TIME_MODE", parameters = Map( "operationType" -> "register_timer", "timeMode" -> TimeMode.None().toString @@ -130,7 +126,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { checkError( ex2, - errorClass = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_TIME_MODE", + condition = "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_TIME_MODE", parameters = Map( "operationType" -> "delete_timer", "timeMode" -> TimeMode.None().toString @@ -145,7 +141,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode)) + UUID.randomUUID(), stringEncoder, getTimeMode(timeMode)) handle.setHandleState(StatefulProcessorHandleState.INITIALIZED) assert(handle.getHandleState === StatefulProcessorHandleState.INITIALIZED) @@ -166,7 +162,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode)) + UUID.randomUUID(), stringEncoder, getTimeMode(timeMode)) handle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED) assert(handle.getHandleState === StatefulProcessorHandleState.DATA_PROCESSED) @@ -206,7 +202,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, getTimeMode(timeMode)) + UUID.randomUUID(), stringEncoder, getTimeMode(timeMode)) Seq(StatefulProcessorHandleState.CREATED, StatefulProcessorHandleState.TIMER_PROCESSED, @@ -223,7 +219,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, TimeMode.ProcessingTime(), + UUID.randomUUID(), stringEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(10)) val valueStateWithTTL = handle.getValueState("testState", @@ -241,7 +237,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, TimeMode.ProcessingTime(), + UUID.randomUUID(), stringEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(10)) val listStateWithTTL = handle.getListState("testState", @@ -259,7 +255,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, TimeMode.ProcessingTime(), + UUID.randomUUID(), stringEncoder, TimeMode.ProcessingTime(), 
batchTimestampMs = Some(10)) val mapStateWithTTL = handle.getMapState("testState", @@ -277,7 +273,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), keyExprEncoder, TimeMode.None()) + UUID.randomUUID(), stringEncoder, TimeMode.None()) handle.getValueState("testValueState", Encoders.STRING) handle.getListState("testListState", Encoders.STRING) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala index df6a3fd7b23e5..24a120be9d9af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.Encoders -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, TimerStateImpl} import org.apache.spark.sql.streaming.TimeMode @@ -45,7 +45,7 @@ class TimerSuite extends StateVariableSuiteBase { ImplicitGroupingKeyTracker.setImplicitKey("test_key") val timerState = new TimerStateImpl(store, timeMode, - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + stringEncoder) timerState.registerTimer(1L * 1000) assert(timerState.listTimers().toSet === Set(1000L)) assert(timerState.getExpiredTimers(Long.MaxValue).toSeq === Seq(("test_key", 1000L))) @@ -64,9 +64,9 @@ class TimerSuite extends StateVariableSuiteBase { ImplicitGroupingKeyTracker.setImplicitKey("test_key") val timerState1 = new TimerStateImpl(store, timeMode, - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + stringEncoder) val timerState2 = new TimerStateImpl(store, timeMode, - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + stringEncoder) timerState1.registerTimer(1L * 1000) timerState2.registerTimer(15L * 1000) assert(timerState1.listTimers().toSet === Set(15000L, 1000L)) @@ -89,7 +89,7 @@ class TimerSuite extends StateVariableSuiteBase { ImplicitGroupingKeyTracker.setImplicitKey("test_key1") val timerState1 = new TimerStateImpl(store, timeMode, - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + stringEncoder) timerState1.registerTimer(1L * 1000) timerState1.registerTimer(2L * 1000) assert(timerState1.listTimers().toSet === Set(1000L, 2000L)) @@ -97,7 +97,7 @@ class TimerSuite extends StateVariableSuiteBase { ImplicitGroupingKeyTracker.setImplicitKey("test_key2") val timerState2 = new TimerStateImpl(store, timeMode, - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + stringEncoder) timerState2.registerTimer(15L * 1000) ImplicitGroupingKeyTracker.removeImplicitKey() @@ -122,7 +122,7 @@ class TimerSuite extends StateVariableSuiteBase { ImplicitGroupingKeyTracker.setImplicitKey("test_key") val timerState = new TimerStateImpl(store, timeMode, - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + stringEncoder) val timerTimerstamps = Seq(931L, 8000L, 452300L, 4200L, 90L, 1L, 2L, 8L, 3L, 35L, 6L, 9L, 5L) // register/put unordered timestamp into rocksDB timerTimerstamps.foreach(timerState.registerTimer) @@ -141,19 +141,19 @@ class TimerSuite extends StateVariableSuiteBase { 
ImplicitGroupingKeyTracker.setImplicitKey("test_key1") val timerState1 = new TimerStateImpl(store, timeMode, - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + stringEncoder) val timerTimestamps1 = Seq(64L, 32L, 1024L, 4096L, 0L, 1L) timerTimestamps1.foreach(timerState1.registerTimer) val timerState2 = new TimerStateImpl(store, timeMode, - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + stringEncoder) val timerTimestamps2 = Seq(931L, 8000L, 452300L, 4200L) timerTimestamps2.foreach(timerState2.registerTimer) ImplicitGroupingKeyTracker.removeImplicitKey() ImplicitGroupingKeyTracker.setImplicitKey("test_key3") val timerState3 = new TimerStateImpl(store, timeMode, - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + stringEncoder) val timerTimerStamps3 = Seq(1L, 2L, 8L, 3L) timerTimerStamps3.foreach(timerState3.registerTimer) ImplicitGroupingKeyTracker.removeImplicitKey() @@ -171,7 +171,7 @@ class TimerSuite extends StateVariableSuiteBase { val store = provider.getStore(0) ImplicitGroupingKeyTracker.setImplicitKey(TestClass(1L, "k1")) val timerState = new TimerStateImpl(store, timeMode, - Encoders.product[TestClass].asInstanceOf[ExpressionEncoder[Any]]) + encoderFor(Encoders.product[TestClass]).asInstanceOf[ExpressionEncoder[Any]]) timerState.registerTimer(1L * 1000) assert(timerState.listTimers().toSet === Set(1000L)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index 7d5b3e4a6bdc9..13d758eb1b88f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkException, SparkUnsupportedOperationException} import org.apache.spark.sql.Encoders -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, ValueStateImplWithTTL} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{TimeMode, TTLConfig, ValueState} @@ -49,7 +49,7 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + stringEncoder, TimeMode.None()) val stateName = "testState" val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) @@ -61,7 +61,7 @@ class ValueStateSuite extends StateVariableSuiteBase { assert(ex.isInstanceOf[SparkException]) checkError( ex.asInstanceOf[SparkException], - errorClass = "INTERNAL_ERROR_TWS", + condition = "INTERNAL_ERROR_TWS", parameters = Map( "message" -> s"Implicit key not found in state store for stateName=$stateName" ), @@ -80,7 +80,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } checkError( ex1.asInstanceOf[SparkException], - errorClass = "INTERNAL_ERROR_TWS", + condition = "INTERNAL_ERROR_TWS", parameters = Map( "message" -> s"Implicit key not found in state store for stateName=$stateName" ), @@ -93,7 +93,7 @@ class ValueStateSuite extends StateVariableSuiteBase { 
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + stringEncoder, TimeMode.None()) val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) ImplicitGroupingKeyTracker.setImplicitKey("test_key") @@ -119,7 +119,7 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + stringEncoder, TimeMode.None()) val testState1: ValueState[Long] = handle.getValueState[Long]( "testState1", Encoders.scalaLong) @@ -164,19 +164,19 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, - UUID.randomUUID(), Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + UUID.randomUUID(), stringEncoder, TimeMode.None()) - val cfName = "_testState" + val cfName = "$testState" val ex = intercept[SparkUnsupportedOperationException] { handle.getValueState[Long](cfName, Encoders.scalaLong) } checkError( ex, - errorClass = "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS", + condition = "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS", parameters = Map( "colFamilyName" -> cfName ), - matchPVals = true + matchPVals = false ) } } @@ -192,7 +192,7 @@ class ValueStateSuite extends StateVariableSuiteBase { } checkError( ex, - errorClass = "UNSUPPORTED_FEATURE.STATE_STORE_MULTIPLE_COLUMN_FAMILIES", + condition = "UNSUPPORTED_FEATURE.STATE_STORE_MULTIPLE_COLUMN_FAMILIES", parameters = Map( "stateStoreProvider" -> "HDFSBackedStateStoreProvider" ), @@ -204,7 +204,7 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + stringEncoder, TimeMode.None()) val testState: ValueState[Double] = handle.getValueState[Double]("testState", Encoders.scalaDouble) @@ -230,7 +230,7 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + stringEncoder, TimeMode.None()) val testState: ValueState[Long] = handle.getValueState[Long]("testState", Encoders.scalaLong) @@ -256,7 +256,7 @@ class ValueStateSuite extends StateVariableSuiteBase { tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + stringEncoder, TimeMode.None()) val testState: ValueState[TestClass] = handle.getValueState[TestClass]("testState", Encoders.product[TestClass]) @@ -282,7 +282,7 @@ class ValueStateSuite extends StateVariableSuiteBase { 
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.None()) + stringEncoder, TimeMode.None()) val testState: ValueState[POJOTestClass] = handle.getValueState[POJOTestClass]("testState", Encoders.bean(classOf[POJOTestClass])) @@ -310,7 +310,7 @@ class ValueStateSuite extends StateVariableSuiteBase { val store = provider.getStore(0) val timestampMs = 10 val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(), + stringEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs)) val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1)) @@ -330,7 +330,7 @@ class ValueStateSuite extends StateVariableSuiteBase { // increment batchProcessingTime, or watermark and ensure expired value is not returned val nextBatchHandle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + stringEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(ttlExpirationMs)) val nextBatchTestState: ValueStateImplWithTTL[String] = @@ -366,7 +366,7 @@ class ValueStateSuite extends StateVariableSuiteBase { val store = provider.getStore(0) val batchTimestampMs = 10 val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], + stringEncoder, TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs)) Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration => @@ -377,7 +377,7 @@ class ValueStateSuite extends StateVariableSuiteBase { checkError( ex, - errorClass = "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE", + condition = "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE", parameters = Map( "operationType" -> "update", "stateName" -> "testState" @@ -393,8 +393,8 @@ class ValueStateSuite extends StateVariableSuiteBase { val store = provider.getStore(0) val timestampMs = 10 val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), - Encoders.product[TestClass].asInstanceOf[ExpressionEncoder[Any]], TimeMode.ProcessingTime(), - batchTimestampMs = Some(timestampMs)) + encoderFor(Encoders.product[TestClass]).asInstanceOf[ExpressionEncoder[Any]], + TimeMode.ProcessingTime(), batchTimestampMs = Some(timestampMs)) val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1)) val testState: ValueStateImplWithTTL[POJOTestClass] = @@ -437,6 +437,8 @@ abstract class StateVariableSuiteBase extends SharedSparkSession import StateStoreTestsHelper._ + protected val stringEncoder = encoderFor(Encoders.STRING).asInstanceOf[ExpressionEncoder[Any]] + // dummy schema for initializing rocksdb provider protected def schemaForKeyRow: StructType = new StructType().add("key", BinaryType) protected def schemaForValueRow: StructType = new StructType().add("value", BinaryType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/UISeleniumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/UISeleniumSuite.scala index 111e233c04e32..8c10d646e935c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/UISeleniumSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/UISeleniumSuite.scala @@ -127,7 +127,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser { exception = 
intercept[SparkRuntimeException] { spark.sql(s"SELECT raise_error('$errorMsg')").collect() }, - errorClass = "USER_RAISED_EXCEPTION", + condition = "USER_RAISED_EXCEPTION", parameters = Map("errorMessage" -> escape)) eventually(timeout(10.seconds), interval(100.milliseconds)) { val summary = findErrorSummaryOnSQLUI() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index e92428f371e05..6eff610433c9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -67,7 +67,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { new ExpressionInfo( "testClass", null, "testName", null, "", "", "", invalidGroupName, "", "", "") }, - errorClass = "_LEGACY_ERROR_TEMP_3202", + condition = "_LEGACY_ERROR_TEMP_3202", parameters = Map( "exprName" -> "testName", "group" -> invalidGroupName, @@ -91,7 +91,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { new ExpressionInfo( "testClass", null, "testName", null, "", "", "", "", "", "", invalidSource) }, - errorClass = "_LEGACY_ERROR_TEMP_3203", + condition = "_LEGACY_ERROR_TEMP_3203", parameters = Map( "exprName" -> "testName", "source" -> invalidSource, @@ -104,7 +104,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { exception = intercept[SparkIllegalArgumentException] { new ExpressionInfo("testClass", null, "testName", null, "", "", invalidNote, "", "", "", "") }, - errorClass = "_LEGACY_ERROR_TEMP_3201", + condition = "_LEGACY_ERROR_TEMP_3201", parameters = Map("exprName" -> "testName", "note" -> invalidNote)) val invalidSince = "-3.0.0" @@ -113,7 +113,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { new ExpressionInfo( "testClass", null, "testName", null, "", "", "", "", invalidSince, "", "") }, - errorClass = "_LEGACY_ERROR_TEMP_3204", + condition = "_LEGACY_ERROR_TEMP_3204", parameters = Map("since" -> invalidSince, "exprName" -> "testName")) val invalidDeprecated = " invalid deprecated" @@ -122,7 +122,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { new ExpressionInfo( "testClass", null, "testName", null, "", "", "", "", "", invalidDeprecated, "") }, - errorClass = "_LEGACY_ERROR_TEMP_3205", + condition = "_LEGACY_ERROR_TEMP_3205", parameters = Map("exprName" -> "testName", "deprecated" -> invalidDeprecated)) } @@ -243,7 +243,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { // Examples can change settings. We clone the session to prevent tests clashing. val clonedSpark = spark.cloneSession() // Coalescing partitions can change result order, so disable it. 
- clonedSpark.conf.set(SQLConf.COALESCE_PARTITIONS_ENABLED, false) + clonedSpark.conf.set(SQLConf.COALESCE_PARTITIONS_ENABLED.key, false) val info = clonedSpark.sessionState.catalog.lookupFunctionInfo(funcId) val className = info.getClassName if (!ignoreSet.contains(className)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 7c929b5da872a..401b17d2b24a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -700,7 +700,8 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf val description = "this is a test table" withTable("t") { - withTempDir { dir => + withTempDir { baseDir => + val dir = new File(baseDir, "test%prefix") spark.catalog.createTable( tableName = "t", source = "json", @@ -778,7 +779,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf exception = intercept[AnalysisException] { spark.catalog.recoverPartitions("my_temp_table") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> "`my_temp_table`", "operation" -> "recoverPartitions()") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 05cd9800bdf21..82795e551b6bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -47,7 +47,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { // Set a conf first. spark.conf.set(testKey, testVal) // Clear the conf. - spark.sessionState.conf.clear() + sqlConf.clear() // After clear, only overrideConfs used by unit test should be in the SQLConf. 
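// A minimal sketch of the configuration pattern applied throughout the SQLConfSuite hunks
// below, assuming `sqlConf` is the shared test-suite alias for `spark.sessionState.conf`
// that these changes rely on: typed reads of a ConfigEntry go through SQLConf.getConf,
// writes via the public RuntimeConfig API use the entry's string key, and the session
// conf is reset with clear().
//
//   val groupByOrdinal: Boolean = sqlConf.getConf(SQLConf.GROUP_BY_ORDINAL)  // typed read
//   spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "10")                     // string-key write
//   sqlConf.clear()                                                          // reset to defaults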
assert(spark.conf.getAll === TestSQLContext.overrideConfs) @@ -62,11 +62,11 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { assert(spark.conf.get(testKey, testVal + "_") === testVal) assert(spark.conf.getAll.contains(testKey)) - spark.sessionState.conf.clear() + sqlConf.clear() } test("parse SQL set commands") { - spark.sessionState.conf.clear() + sqlConf.clear() sql(s"set $testKey=$testVal") assert(spark.conf.get(testKey, testVal + "_") === testVal) assert(spark.conf.get(testKey, testVal + "_") === testVal) @@ -84,11 +84,11 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { sql(s"set $key=") assert(spark.conf.get(key, "0") === "") - spark.sessionState.conf.clear() + sqlConf.clear() } test("set command for display") { - spark.sessionState.conf.clear() + sqlConf.clear() checkAnswer( sql("SET").where("key = 'spark.sql.groupByOrdinal'").select("key", "value"), Nil) @@ -109,11 +109,11 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { } test("deprecated property") { - spark.sessionState.conf.clear() - val original = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) + sqlConf.clear() + val original = sqlConf.getConf(SQLConf.SHUFFLE_PARTITIONS) try { sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) === 10) + assert(sqlConf.getConf(SQLConf.SHUFFLE_PARTITIONS) === 10) } finally { sql(s"set ${SQLConf.SHUFFLE_PARTITIONS.key}=$original") } @@ -146,18 +146,18 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { } test("reset - public conf") { - spark.sessionState.conf.clear() - val original = spark.conf.get(SQLConf.GROUP_BY_ORDINAL) + sqlConf.clear() + val original = sqlConf.getConf(SQLConf.GROUP_BY_ORDINAL) try { - assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL)) + assert(sqlConf.getConf(SQLConf.GROUP_BY_ORDINAL)) sql(s"set ${SQLConf.GROUP_BY_ORDINAL.key}=false") - assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL) === false) + assert(sqlConf.getConf(SQLConf.GROUP_BY_ORDINAL) === false) assert(sql(s"set").where(s"key = '${SQLConf.GROUP_BY_ORDINAL.key}'").count() == 1) - assert(spark.conf.get(SQLConf.OPTIMIZER_EXCLUDED_RULES).isEmpty) + assert(sqlConf.getConf(SQLConf.OPTIMIZER_EXCLUDED_RULES).isEmpty) sql(s"reset") - assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL)) + assert(sqlConf.getConf(SQLConf.GROUP_BY_ORDINAL)) assert(sql(s"set").where(s"key = '${SQLConf.GROUP_BY_ORDINAL.key}'").count() == 0) - assert(spark.conf.get(SQLConf.OPTIMIZER_EXCLUDED_RULES) === + assert(sqlConf.getConf(SQLConf.OPTIMIZER_EXCLUDED_RULES) === Some("org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation")) } finally { sql(s"set ${SQLConf.GROUP_BY_ORDINAL.key}=$original") @@ -165,15 +165,15 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { } test("reset - internal conf") { - spark.sessionState.conf.clear() - val original = spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS) + sqlConf.clear() + val original = sqlConf.getConf(SQLConf.OPTIMIZER_MAX_ITERATIONS) try { - assert(spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS) === 100) + assert(sqlConf.getConf(SQLConf.OPTIMIZER_MAX_ITERATIONS) === 100) sql(s"set ${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}=10") - assert(spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS) === 10) + assert(sqlConf.getConf(SQLConf.OPTIMIZER_MAX_ITERATIONS) === 10) assert(sql(s"set").where(s"key = '${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}'").count() == 1) sql(s"reset") - assert(spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS) === 100) + 
assert(sqlConf.getConf(SQLConf.OPTIMIZER_MAX_ITERATIONS) === 100) assert(sql(s"set").where(s"key = '${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}'").count() == 0) } finally { sql(s"set ${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}=$original") @@ -181,7 +181,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { } test("reset - user-defined conf") { - spark.sessionState.conf.clear() + sqlConf.clear() val userDefinedConf = "x.y.z.reset" try { assert(spark.conf.getOption(userDefinedConf).isEmpty) @@ -196,7 +196,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { } test("SPARK-32406: reset - single configuration") { - spark.sessionState.conf.clear() + sqlConf.clear() // spark core conf w/o entry registered val appId = spark.sparkContext.getConf.getAppId sql("RESET spark.app.id") @@ -204,7 +204,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { // spark core conf w/ entry registered checkError( exception = intercept[AnalysisException](sql("RESET spark.executor.cores")), - errorClass = "CANNOT_MODIFY_CONFIG", + condition = "CANNOT_MODIFY_CONFIG", parameters = Map("key" -> "\"spark.executor.cores\"", "docroot" -> SPARK_DOC_ROOT) ) @@ -216,24 +216,24 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { sql("RESET spark.abc") // ignore nonexistent keys // runtime sql configs - val original = spark.conf.get(SQLConf.GROUP_BY_ORDINAL) + val original = sqlConf.getConf(SQLConf.GROUP_BY_ORDINAL) sql(s"SET ${SQLConf.GROUP_BY_ORDINAL.key}=false") sql(s"RESET ${SQLConf.GROUP_BY_ORDINAL.key}") - assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL) === original) + assert(sqlConf.getConf(SQLConf.GROUP_BY_ORDINAL) === original) // runtime sql configs with optional defaults - assert(spark.conf.get(SQLConf.OPTIMIZER_EXCLUDED_RULES).isEmpty) + assert(sqlConf.getConf(SQLConf.OPTIMIZER_EXCLUDED_RULES).isEmpty) sql(s"RESET ${SQLConf.OPTIMIZER_EXCLUDED_RULES.key}") - assert(spark.conf.get(SQLConf.OPTIMIZER_EXCLUDED_RULES) === + assert(sqlConf.getConf(SQLConf.OPTIMIZER_EXCLUDED_RULES) === Some("org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation")) sql(s"SET ${SQLConf.PLAN_CHANGE_LOG_RULES.key}=abc") sql(s"RESET ${SQLConf.PLAN_CHANGE_LOG_RULES.key}") - assert(spark.conf.get(SQLConf.PLAN_CHANGE_LOG_RULES).isEmpty) + assert(sqlConf.getConf(SQLConf.PLAN_CHANGE_LOG_RULES).isEmpty) // static sql configs checkError( exception = intercept[AnalysisException](sql(s"RESET ${StaticSQLConf.WAREHOUSE_PATH.key}")), - errorClass = "_LEGACY_ERROR_TEMP_1325", + condition = "_LEGACY_ERROR_TEMP_1325", parameters = Map("key" -> "spark.sql.warehouse.dir")) } @@ -247,19 +247,19 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { } test("Test ADVISORY_PARTITION_SIZE_IN_BYTES's method") { - spark.sessionState.conf.clear() + sqlConf.clear() spark.conf.set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key, "100") - assert(spark.conf.get(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES) === 100) + assert(sqlConf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES) === 100) spark.conf.set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key, "1k") - assert(spark.conf.get(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES) === 1024) + assert(sqlConf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES) === 1024) spark.conf.set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key, "1M") - assert(spark.conf.get(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES) === 1048576) + assert(sqlConf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES) === 1048576) spark.conf.set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key, "1g") - 
assert(spark.conf.get(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES) === 1073741824) + assert(sqlConf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES) === 1073741824) // test negative value intercept[IllegalArgumentException] { @@ -277,7 +277,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { spark.conf.set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key, "-90000000000g") } - spark.sessionState.conf.clear() + sqlConf.clear() } test("SparkSession can access configs set in SparkConf") { @@ -305,7 +305,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { try { sparkContext.conf.set(GLOBAL_TEMP_DATABASE, "a") val newSession = new SparkSession(sparkContext) - assert(newSession.conf.get(GLOBAL_TEMP_DATABASE) == "a") + assert(newSession.sessionState.conf.getConf(GLOBAL_TEMP_DATABASE) == "a") checkAnswer( newSession.sql(s"SET ${GLOBAL_TEMP_DATABASE.key}"), Row(GLOBAL_TEMP_DATABASE.key, "a")) @@ -338,16 +338,16 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { } test("SPARK-10365: PARQUET_OUTPUT_TIMESTAMP_TYPE") { - spark.sessionState.conf.clear() + sqlConf.clear() // check default value assert(spark.sessionState.conf.parquetOutputTimestampType == SQLConf.ParquetOutputTimestampType.INT96) - spark.conf.set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE, "timestamp_micros") + sqlConf.setConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE, "timestamp_micros") assert(spark.sessionState.conf.parquetOutputTimestampType == SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS) - spark.conf.set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE, "int96") + sqlConf.setConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE, "int96") assert(spark.sessionState.conf.parquetOutputTimestampType == SQLConf.ParquetOutputTimestampType.INT96) @@ -356,7 +356,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { spark.conf.set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, "invalid") } - spark.sessionState.conf.clear() + sqlConf.clear() } test("SPARK-22779: correctly compute default value for fallback configs") { @@ -373,10 +373,10 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { .get assert(displayValue === fallback.defaultValueString) - spark.conf.set(SQLConf.PARQUET_COMPRESSION, GZIP.lowerCaseName()) + sqlConf.setConf(SQLConf.PARQUET_COMPRESSION, GZIP.lowerCaseName()) assert(spark.conf.get(fallback.key) === GZIP.lowerCaseName()) - spark.conf.set(fallback, LZO.lowerCaseName()) + sqlConf.setConf(fallback, LZO.lowerCaseName()) assert(spark.conf.get(fallback.key) === LZO.lowerCaseName()) val newDisplayValue = spark.sessionState.conf.getAllDefinedConfs @@ -450,7 +450,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkIllegalArgumentException] { spark.conf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, invalidTz) }, - errorClass = "INVALID_CONF_VALUE.TIME_ZONE", + condition = "INVALID_CONF_VALUE.TIME_ZONE", parameters = Map( "confValue" -> invalidTz, "confName" -> SQLConf.SESSION_LOCAL_TIMEZONE.key)) @@ -459,24 +459,24 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { test("set time zone") { TimeZone.getAvailableIDs().foreach { zid => sql(s"set time zone '$zid'") - assert(spark.conf.get(SQLConf.SESSION_LOCAL_TIMEZONE) === zid) + assert(sqlConf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE) === zid) } sql("set time zone local") - assert(spark.conf.get(SQLConf.SESSION_LOCAL_TIMEZONE) === TimeZone.getDefault.getID) + assert(sqlConf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE) === TimeZone.getDefault.getID) val tz = "Invalid TZ" checkError( exception = 
intercept[SparkIllegalArgumentException] { sql(s"SET TIME ZONE '$tz'").collect() }, - errorClass = "INVALID_CONF_VALUE.TIME_ZONE", + condition = "INVALID_CONF_VALUE.TIME_ZONE", parameters = Map( "confValue" -> tz, "confName" -> SQLConf.SESSION_LOCAL_TIMEZONE.key)) (-18 to 18).map(v => (v, s"interval '$v' hours")).foreach { case (i, interval) => sql(s"set time zone $interval") - val zone = spark.conf.get(SQLConf.SESSION_LOCAL_TIMEZONE) + val zone = sqlConf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE) if (i == 0) { assert(zone === "Z") } else { @@ -486,7 +486,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { val sqlText = "set time zone interval 19 hours" checkError( exception = intercept[ParseException](sql(sqlText)), - errorClass = "_LEGACY_ERROR_TEMP_0044", + condition = "_LEGACY_ERROR_TEMP_0044", parameters = Map.empty, context = ExpectedContext(sqlText, 0, 30)) } @@ -504,14 +504,14 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { test("SPARK-47765: set collation") { Seq("UNICODE", "UNICODE_CI", "utf8_lcase", "utf8_binary").foreach { collation => sql(s"set collation $collation") - assert(spark.conf.get(SQLConf.DEFAULT_COLLATION) === collation.toUpperCase(Locale.ROOT)) + assert(sqlConf.getConf(SQLConf.DEFAULT_COLLATION) === collation.toUpperCase(Locale.ROOT)) } checkError( exception = intercept[SparkIllegalArgumentException] { sql(s"SET COLLATION unicode_c").collect() }, - errorClass = "INVALID_CONF_VALUE.DEFAULT_COLLATION", + condition = "INVALID_CONF_VALUE.DEFAULT_COLLATION", parameters = Map( "confValue" -> "UNICODE_C", "confName" -> "spark.sql.session.collation.default", @@ -522,7 +522,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { test("SPARK-43028: config not found error") { checkError( exception = intercept[SparkNoSuchElementException](spark.conf.get("some.conf")), - errorClass = "SQL_CONF_NOT_FOUND", + condition = "SQL_CONF_NOT_FOUND", parameters = Map("sqlConf" -> "\"some.conf\"")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SharedStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SharedStateSuite.scala index d3154d0125af8..b323c6366f520 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SharedStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SharedStateSuite.scala @@ -61,7 +61,7 @@ class SharedStateSuite extends SharedSparkSession { exception = intercept[SparkException] { spark.sharedState.externalCatalog }, - errorClass = "DEFAULT_DATABASE_NOT_EXISTS", + condition = "DEFAULT_DATABASE_NOT_EXISTS", parameters = Map("defaultDatabase" -> "default_database_not_exists") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 77997f95188c0..e8a8a0ae47bfd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1520,7 +1520,7 @@ class JDBCSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkSQLException] { spark.read.jdbc(urlWithUserAndPass, "TEST.ARRAY_TABLE", new Properties()).collect() }, - errorClass = "UNRECOGNIZED_SQL_TYPE", + condition = "UNRECOGNIZED_SQL_TYPE", parameters = Map("typeName" -> "INTEGER ARRAY", "jdbcType" -> "ARRAY")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 780cc86bb6a61..054c7e644ff55 
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -399,7 +399,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel exception = intercept[AnalysisException] { df.collect() }, - errorClass = "NULL_DATA_SOURCE_OPTION", + condition = "NULL_DATA_SOURCE_OPTION", parameters = Map( "option" -> "pushDownOffset") ) @@ -2943,7 +2943,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel exception = intercept[AnalysisException] { checkAnswer(sql("SELECT h2.test.my_avg2(id) FROM h2.test.people"), Seq.empty) }, - errorClass = "UNRESOLVED_ROUTINE", + condition = "UNRESOLVED_ROUTINE", parameters = Map( "routineName" -> "`h2`.`test`.`my_avg2`", "searchPath" -> "[`system`.`builtin`, `system`.`session`, `h2`.`default`]"), @@ -2955,7 +2955,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel exception = intercept[AnalysisException] { checkAnswer(sql("SELECT h2.my_avg2(id) FROM h2.test.people"), Seq.empty) }, - errorClass = "UNRESOLVED_ROUTINE", + condition = "UNRESOLVED_ROUTINE", parameters = Map( "routineName" -> "`h2`.`my_avg2`", "searchPath" -> "[`system`.`builtin`, `system`.`session`, `h2`.`default`]"), @@ -3038,7 +3038,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel exception = intercept[IndexAlreadyExistsException] { sql(s"CREATE INDEX people_index ON TABLE h2.test.people (id)") }, - errorClass = "INDEX_ALREADY_EXISTS", + condition = "INDEX_ALREADY_EXISTS", parameters = Map( "indexName" -> "`people_index`", "tableName" -> "`test`.`people`" @@ -3056,7 +3056,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel exception = intercept[NoSuchIndexException] { sql(s"DROP INDEX people_index ON TABLE h2.test.people") }, - errorClass = "INDEX_NOT_FOUND", + condition = "INDEX_NOT_FOUND", parameters = Map("indexName" -> "`people_index`", "tableName" -> "`test`.`people`") ) assert(jdbcTable.indexExists("people_index") == false) @@ -3073,7 +3073,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel exception = intercept[AnalysisException] { sql("SELECT * FROM h2.test.people where h2.db_name.schema_name.function_name()") }, - errorClass = "IDENTIFIER_TOO_MANY_NAME_PARTS", + condition = "IDENTIFIER_TOO_MANY_NAME_PARTS", sqlState = "42601", parameters = Map("identifier" -> "`db_name`.`schema_name`.`function_name`") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 76a092b552f98..e7044ea50f54f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -191,7 +191,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { exception = intercept[AnalysisException] { df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties()) }, - errorClass = "_LEGACY_ERROR_TEMP_1156", + condition = "_LEGACY_ERROR_TEMP_1156", parameters = Map( "colName" -> "NAME", "tableSchema" -> @@ -224,7 +224,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { df3.write.mode(SaveMode.Overwrite).option("truncate", true) .jdbc(url1, "TEST.TRUNCATETEST", properties) }, - errorClass = "_LEGACY_ERROR_TEMP_1156", + condition = "_LEGACY_ERROR_TEMP_1156", parameters = Map( "colName" -> 
"seq", "tableSchema" -> @@ -256,7 +256,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { exception = intercept[AnalysisException] { df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties()) }, - errorClass = "_LEGACY_ERROR_TEMP_1156", + condition = "_LEGACY_ERROR_TEMP_1156", parameters = Map( "colName" -> "seq", "tableSchema" -> @@ -507,7 +507,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { .option("createTableColumnTypes", "name CLOB(2000)") .jdbc(url1, "TEST.USERDBTYPETEST", properties) }, - errorClass = "UNSUPPORTED_DATATYPE", + condition = "UNSUPPORTED_DATATYPE", parameters = Map("typeName" -> "\"CLOB(2000)\"")) } @@ -519,7 +519,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { .option("createTableColumnTypes", "`name char(20)") // incorrectly quoted column .jdbc(url1, "TEST.USERDBTYPETEST", properties) }, - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'`'", "hint" -> "")) } @@ -533,7 +533,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { } checkError( exception = e, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`name`")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 5c36f9e19e6d9..83d8191d01ec1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -45,23 +45,17 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi override def output: Seq[Attribute] = Seq.empty } - case class TestWhileCondition( + case class TestLoopCondition( condVal: Boolean, reps: Int, description: String) extends SingleStatementExec( parsedPlan = DummyLogicalPlan(), Origin(startIndex = Some(0), stopIndex = Some(description.length)), isInternal = false) - case class TestWhile( - condition: TestWhileCondition, - body: CompoundBodyExec) - extends WhileStatementExec(condition, body, spark) { - + class LoopBooleanConditionEvaluator(condition: TestLoopCondition) { private var callCount: Int = 0 - override def evaluateBooleanCondition( - session: SparkSession, - statement: LeafStatementExec): Boolean = { + def evaluateLoopBooleanCondition(): Boolean = { if (callCount < condition.reps) { callCount += 1 condition.condVal @@ -72,11 +66,39 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } } + case class TestWhile( + condition: TestLoopCondition, + body: CompoundBodyExec, + label: Option[String] = None) + extends WhileStatementExec(condition, body, label, spark) { + + private val evaluator = new LoopBooleanConditionEvaluator(condition) + + override def evaluateBooleanCondition( + session: SparkSession, + statement: LeafStatementExec): Boolean = evaluator.evaluateLoopBooleanCondition() + } + + case class TestRepeat( + condition: TestLoopCondition, + body: CompoundBodyExec, + label: Option[String] = None) + extends RepeatStatementExec(condition, body, label, spark) { + + private val evaluator = new LoopBooleanConditionEvaluator(condition) + + override def evaluateBooleanCondition( + session: SparkSession, + statement: LeafStatementExec): Boolean = evaluator.evaluateLoopBooleanCondition() + } + private def 
extractStatementValue(statement: CompoundStatementExec): String = statement match { case TestLeafStatement(testVal) => testVal case TestIfElseCondition(_, description) => description - case TestWhileCondition(_, _, description) => description + case TestLoopCondition(_, _, description) => description + case leaveStmt: LeaveStatementExec => leaveStmt.label + case iterateStmt: IterateStatementExec => iterateStmt.label case _ => fail("Unexpected statement type") } @@ -262,7 +284,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("while - doesn't enter body") { val iter = new CompoundBodyExec(Seq( TestWhile( - condition = TestWhileCondition(condVal = true, reps = 0, description = "con1"), + condition = TestLoopCondition(condVal = true, reps = 0, description = "con1"), body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) ) )).getTreeIterator @@ -273,7 +295,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("while - enters body once") { val iter = new CompoundBodyExec(Seq( TestWhile( - condition = TestWhileCondition(condVal = true, reps = 1, description = "con1"), + condition = TestLoopCondition(condVal = true, reps = 1, description = "con1"), body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) ) )).getTreeIterator @@ -284,7 +306,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("while - enters body with multiple statements multiple times") { val iter = new CompoundBodyExec(Seq( TestWhile( - condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"), + condition = TestLoopCondition(condVal = true, reps = 2, description = "con1"), body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), TestLeafStatement("statement2"))) @@ -298,10 +320,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("nested while - 2 times outer 2 times inner") { val iter = new CompoundBodyExec(Seq( TestWhile( - condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"), + condition = TestLoopCondition(condVal = true, reps = 2, description = "con1"), body = new CompoundBodyExec(Seq( TestWhile( - condition = TestWhileCondition(condVal = true, reps = 2, description = "con2"), + condition = TestLoopCondition(condVal = true, reps = 2, description = "con2"), body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) )) ) @@ -314,4 +336,337 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "con2", "body1", "con2", "con1")) } + test("repeat - true condition") { + val iter = new CompoundBodyExec(Seq( + TestRepeat( + condition = TestLoopCondition(condVal = false, reps = 0, description = "con1"), + body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("body1", "con1")) + } + + test("repeat - condition false once") { + val iter = new CompoundBodyExec(Seq( + TestRepeat( + condition = TestLoopCondition(condVal = false, reps = 1, description = "con1"), + body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("body1", "con1", "body1", "con1")) + } + + test("repeat - enters body with multiple statements multiple times") { + val iter = new CompoundBodyExec(Seq( + TestRepeat( + condition = TestLoopCondition(condVal = false, reps = 2, description = "con1"), 
+ body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + TestLeafStatement("statement2"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("statement1", "statement2", "con1", "statement1", "statement2", + "con1", "statement1", "statement2", "con1")) + } + + test("nested repeat") { + val iter = new CompoundBodyExec(Seq( + TestRepeat( + condition = TestLoopCondition(condVal = false, reps = 2, description = "con1"), + body = new CompoundBodyExec(Seq( + TestRepeat( + condition = TestLoopCondition(condVal = false, reps = 2, description = "con2"), + body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) + )) + ) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("body1", "con2", "body1", + "con2", "body1", "con2", + "con1", "body1", "con2", + "body1", "con2", "body1", + "con2", "con1", "body1", + "con2", "body1", "con2", + "body1", "con2", "con1")) + } + + test("leave compound block") { + val iter = new CompoundBodyExec( + statements = Seq( + TestLeafStatement("one"), + new LeaveStatementExec("lbl") + ), + label = Some("lbl") + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("one", "lbl")) + } + + test("leave while loop") { + val iter = new CompoundBodyExec( + statements = Seq( + TestWhile( + condition = TestLoopCondition(condVal = true, reps = 2, description = "con1"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new LeaveStatementExec("lbl")) + ), + label = Some("lbl") + ) + ) + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", "body1", "lbl")) + } + + test("leave repeat loop") { + val iter = new CompoundBodyExec( + statements = Seq( + TestRepeat( + condition = TestLoopCondition(condVal = false, reps = 2, description = "con1"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new LeaveStatementExec("lbl")) + ), + label = Some("lbl") + ) + ) + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("body1", "lbl")) + } + + test("iterate while loop") { + val iter = new CompoundBodyExec( + statements = Seq( + TestWhile( + condition = TestLoopCondition(condVal = true, reps = 2, description = "con1"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new IterateStatementExec("lbl"), + TestLeafStatement("body2")) + ), + label = Some("lbl") + ) + ) + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", "body1", "lbl", "con1", "body1", "lbl", "con1")) + } + + test("iterate repeat loop") { + val iter = new CompoundBodyExec( + statements = Seq( + TestRepeat( + condition = TestLoopCondition(condVal = false, reps = 2, description = "con1"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new IterateStatementExec("lbl"), + TestLeafStatement("body2")) + ), + label = Some("lbl") + ) + ) + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert( + statements === Seq("body1", "lbl", "con1", "body1", "lbl", "con1", "body1", "lbl", "con1")) + } + + test("leave outer loop from nested while loop") { + val iter = new CompoundBodyExec( + statements = Seq( + TestWhile( + condition = TestLoopCondition(condVal = true, reps = 2, description = "con1"), + body = new CompoundBodyExec(Seq( + TestWhile( + condition = TestLoopCondition(condVal = true, reps = 2, 
description = "con2"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new LeaveStatementExec("lbl")) + ), + label = Some("lbl2") + ) + )), + label = Some("lbl") + ) + ) + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", "con2", "body1", "lbl")) + } + + test("leave outer loop from nested repeat loop") { + val iter = new CompoundBodyExec( + statements = Seq( + TestRepeat( + condition = TestLoopCondition(condVal = false, reps = 2, description = "con1"), + body = new CompoundBodyExec(Seq( + TestRepeat( + condition = TestLoopCondition(condVal = false, reps = 2, description = "con2"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new LeaveStatementExec("lbl")) + ), + label = Some("lbl2") + ) + )), + label = Some("lbl") + ) + ) + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("body1", "lbl")) + } + + test("iterate outer loop from nested while loop") { + val iter = new CompoundBodyExec( + statements = Seq( + TestWhile( + condition = TestLoopCondition(condVal = true, reps = 2, description = "con1"), + body = new CompoundBodyExec(Seq( + TestWhile( + condition = TestLoopCondition(condVal = true, reps = 2, description = "con2"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new IterateStatementExec("lbl"), + TestLeafStatement("body2")) + ), + label = Some("lbl2") + ) + )), + label = Some("lbl") + ) + ) + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "con1", "con2", "body1", "lbl", + "con1", "con2", "body1", "lbl", + "con1")) + } + + test("iterate outer loop from nested repeat loop") { + val iter = new CompoundBodyExec( + statements = Seq( + TestRepeat( + condition = TestLoopCondition(condVal = false, reps = 2, description = "con1"), + body = new CompoundBodyExec(Seq( + TestRepeat( + condition = TestLoopCondition(condVal = false, reps = 2, description = "con2"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new IterateStatementExec("lbl"), + TestLeafStatement("body2")) + ), + label = Some("lbl2") + ) + )), + label = Some("lbl") + ) + ) + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "body1", "lbl", "con1", + "body1", "lbl", "con1", + "body1", "lbl", "con1")) + } + + test("searched case - enter first WHEN clause") { + val iter = new CompoundBodyExec(Seq( + new CaseStatementExec( + conditions = Seq( + TestIfElseCondition(condVal = true, description = "con1"), + TestIfElseCondition(condVal = false, description = "con2") + ), + conditionalBodies = Seq( + new CompoundBodyExec(Seq(TestLeafStatement("body1"))), + new CompoundBodyExec(Seq(TestLeafStatement("body2"))) + ), + elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body3")))), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", "body1")) + } + + test("searched case - enter body of the ELSE clause") { + val iter = new CompoundBodyExec(Seq( + new CaseStatementExec( + conditions = Seq( + TestIfElseCondition(condVal = false, description = "con1") + ), + conditionalBodies = Seq( + new CompoundBodyExec(Seq(TestLeafStatement("body1"))) + ), + elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body2")))), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", 
"body2")) + } + + test("searched case - enter second WHEN clause") { + val iter = new CompoundBodyExec(Seq( + new CaseStatementExec( + conditions = Seq( + TestIfElseCondition(condVal = false, description = "con1"), + TestIfElseCondition(condVal = true, description = "con2") + ), + conditionalBodies = Seq( + new CompoundBodyExec(Seq(TestLeafStatement("body1"))), + new CompoundBodyExec(Seq(TestLeafStatement("body2"))) + ), + elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body3")))), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", "con2", "body2")) + } + + test("searched case - without else (successful check)") { + val iter = new CompoundBodyExec(Seq( + new CaseStatementExec( + conditions = Seq( + TestIfElseCondition(condVal = false, description = "con1"), + TestIfElseCondition(condVal = true, description = "con2") + ), + conditionalBodies = Seq( + new CompoundBodyExec(Seq(TestLeafStatement("body1"))), + new CompoundBodyExec(Seq(TestLeafStatement("body2"))) + ), + elseBody = None, + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", "con2", "body2")) + } + + test("searched case - without else (unsuccessful checks)") { + val iter = new CompoundBodyExec(Seq( + new CaseStatementExec( + conditions = Seq( + TestIfElseCondition(condVal = false, description = "con1"), + TestIfElseCondition(condVal = false, description = "con2") + ), + conditionalBodies = Seq( + new CompoundBodyExec(Seq(TestLeafStatement("body1"))), + new CompoundBodyExec(Seq(TestLeafStatement("body2"))) + ), + elseBody = None, + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", "con2")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 592516de84c17..ac190eb48d1f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.scripting -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkNumberFormatException} import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row} import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.exceptions.SqlScriptingException +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession /** @@ -190,7 +191,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } checkError( exception = e, - errorClass = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", sqlState = "42703", parameters = Map("objectName" -> s"`$varName`"), context = ExpectedContext( @@ -368,6 +369,391 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } + test("searched case") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1 = 1 THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("searched case nested") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1=1 THEN + | CASE + | WHEN 2=1 THEN + | SELECT 41; + | ELSE 
+ | SELECT 42; + | END CASE; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("searched case second case") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1 = (SELECT 2) THEN + | SELECT 1; + | WHEN 2 = 2 THEN + | SELECT 42; + | WHEN (SELECT * FROM t) THEN + | SELECT * FROM b; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("searched case going in else") { + val commands = + """ + |BEGIN + | CASE + | WHEN 2 = 1 THEN + | SELECT 1; + | WHEN 3 IN (1,2) THEN + | SELECT 2; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + + test("searched case with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |CASE + | WHEN (SELECT COUNT(*) > 2 FROM t) THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + + val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + + test("searched case else with count") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + | INSERT INTO t VALUES (1, 'a', 1.0); + | INSERT INTO t VALUES (1, 'a', 1.0); + | CASE + | WHEN (SELECT COUNT(*) > 2 FROM t) THEN + | SELECT 42; + | WHEN (SELECT COUNT(*) > 1 FROM t) THEN + | SELECT 43; + | ELSE + | SELECT 44; + | END CASE; + |END + |""".stripMargin + + val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + + test("searched case no cases matched no else") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1 = 2 THEN + | SELECT 42; + | WHEN 1 = 3 THEN + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq() + verifySqlScriptResult(commands, expected) + } + + test("searched case when evaluates to null") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a BOOLEAN) USING parquet; + | CASE + | WHEN (SELECT * FROM t) THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + + checkError( + exception = intercept[SqlScriptingException] ( + runSqlScript(commands) + ), + condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW", + parameters = Map("invalidStatement" -> "(SELECT * FROM T)") + ) + } + } + + test("searched case with non boolean condition - constant") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1 THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + + checkError( + exception = intercept[SqlScriptingException] ( + runSqlScript(commands) + ), + condition = "INVALID_BOOLEAN_STATEMENT", + parameters = Map("invalidStatement" -> "1") + ) + } + + test("searched case with too many rows in subquery condition") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a BOOLEAN) USING parquet; + | INSERT INTO t VALUES (true); + | INSERT INTO t VALUES (true); + | CASE + | WHEN (SELECT * FROM t) THEN + | SELECT 1; + | END CASE; + |END + |""".stripMargin + + checkError( + exception = intercept[SparkException] ( + runSqlScript(commands) + ), + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS", + parameters = Map.empty, + context = ExpectedContext(fragment = "(SELECT * FROM t)", start = 124, stop = 140) + ) + } + } + + 
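// The tests that follow switch to the "simple" CASE form. A brief sketch of the two
// statement shapes this suite covers, using only syntax that already appears in the
// tests themselves:
//
//   -- searched CASE: each WHEN holds a boolean condition, evaluated top to bottom;
//   -- the first true condition's body runs, otherwise the optional ELSE body runs.
//   CASE
//     WHEN 1 = 2 THEN SELECT 1;
//     WHEN 2 = 2 THEN SELECT 42;
//     ELSE SELECT 43;
//   END CASE;
//
//   -- simple CASE: a single operand is compared for equality against each WHEN value.
//   CASE (SELECT 2)
//     WHEN 1 THEN SELECT 1;
//     WHEN 2 THEN SELECT 42;
//     ELSE SELECT 43;
//   END CASE;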
test("simple case") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 1 THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("simple case nested") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 1 THEN + | CASE 2 + | WHEN (SELECT 3) THEN + | SELECT 41; + | ELSE + | SELECT 42; + | END CASE; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("simple case second case") { + val commands = + """ + |BEGIN + | CASE (SELECT 2) + | WHEN 1 THEN + | SELECT 1; + | WHEN 2 THEN + | SELECT 42; + | WHEN (SELECT * FROM t) THEN + | SELECT * FROM b; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("simple case going in else") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 2 THEN + | SELECT 1; + | WHEN 3 THEN + | SELECT 2; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + + test("simple case with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |CASE (SELECT COUNT(*) FROM t) + | WHEN 1 THEN + | SELECT 41; + | WHEN 2 THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + + val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + } + + test("simple case else with count") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + | INSERT INTO t VALUES (1, 'a', 1.0); + | INSERT INTO t VALUES (2, 'b', 2.0); + | CASE (SELECT COUNT(*) FROM t) + | WHEN 1 THEN + | SELECT 42; + | WHEN 3 THEN + | SELECT 43; + | ELSE + | SELECT 44; + | END CASE; + |END + |""".stripMargin + + val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(44))) + verifySqlScriptResult(commands, expected) + } + } + + test("simple case no cases matched no else") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 2 THEN + | SELECT 42; + | WHEN 3 THEN + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq() + verifySqlScriptResult(commands, expected) + } + + test("simple case mismatched types") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN "one" THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + checkError( + exception = intercept[SparkNumberFormatException]( + runSqlScript(commands) + ), + condition = "CAST_INVALID_INPUT", + parameters = Map( + "expression" -> "'one'", + "sourceType" -> "\"STRING\"", + "targetType" -> "\"BIGINT\""), + context = ExpectedContext(fragment = "\"one\"", start = 23, stop = 27)) + } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkError( + exception = intercept[SqlScriptingException]( + runSqlScript(commands) + ), + condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW", + parameters = Map("invalidStatement" -> "\"ONE\"")) + } + } + + test("simple case compare with null") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a INT) USING parquet; + | CASE (SELECT COUNT(*) FROM t) + | WHEN 1 THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + + val expected = 
Seq(Seq.empty[Row], Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + test("if's condition must be a boolean statement") { withTable("t") { val commands = @@ -378,13 +764,16 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { | END IF; |END |""".stripMargin + val exception = intercept[SqlScriptingException] { + runSqlScript(commands) + } checkError( - exception = intercept[SqlScriptingException] ( - runSqlScript(commands) - ), - errorClass = "INVALID_BOOLEAN_STATEMENT", + exception = exception, + condition = "INVALID_BOOLEAN_STATEMENT", parameters = Map("invalidStatement" -> "1") ) + assert(exception.origin.line.isDefined) + assert(exception.origin.line.get == 3) } } @@ -400,13 +789,16 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { | END IF; |END |""".stripMargin + val exception = intercept[SqlScriptingException] { + runSqlScript(commands1) + } checkError( - exception = intercept[SqlScriptingException] ( - runSqlScript(commands1) - ), - errorClass = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW", + exception = exception, + condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW", parameters = Map("invalidStatement" -> "(SELECT * FROM T1)") ) + assert(exception.origin.line.isDefined) + assert(exception.origin.line.get == 4) // too many rows ( > 1 ) val commands2 = @@ -424,7 +816,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkException] ( runSqlScript(commands2) ), - errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS", + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS", parameters = Map.empty, context = ExpectedContext(fragment = "(SELECT * FROM t2)", start = 121, stop = 138) ) @@ -536,4 +928,459 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(commands, expected) } } + + test("repeat") { + val commands = + """ + |BEGIN + | DECLARE i = 0; + | REPEAT + | SELECT i; + | SET VAR i = i + 1; + | UNTIL + | i = 3 + | END REPEAT; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare i + Seq(Row(0)), // select i + Seq.empty[Row], // set i + Seq(Row(1)), // select i + Seq.empty[Row], // set i + Seq(Row(2)), // select i + Seq.empty[Row], // set i + Seq.empty[Row] // drop var + ) + verifySqlScriptResult(commands, expected) + } + + test("repeat: enters body only once") { + val commands = + """ + |BEGIN + | DECLARE i = 3; + | REPEAT + | SELECT i; + | SET VAR i = i + 1; + | UNTIL + | 1 = 1 + | END REPEAT; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare i + Seq(Row(3)), // select i + Seq.empty[Row], // set i + Seq.empty[Row] // drop i + ) + verifySqlScriptResult(commands, expected) + } + + test("nested repeat") { + val commands = + """ + |BEGIN + | DECLARE i = 0; + | DECLARE j = 0; + | REPEAT + | SET VAR j = 0; + | REPEAT + | SELECT i, j; + | SET VAR j = j + 1; + | UNTIL j >= 2 + | END REPEAT; + | SET VAR i = i + 1; + | UNTIL i >= 2 + | END REPEAT; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare i + Seq.empty[Row], // declare j + Seq.empty[Row], // set j to 0 + Seq(Row(0, 0)), // select i, j + Seq.empty[Row], // increase j + Seq(Row(0, 1)), // select i, j + Seq.empty[Row], // increase j + Seq.empty[Row], // increase i + Seq.empty[Row], // set j to 0 + Seq(Row(1, 0)), // select i, j + Seq.empty[Row], // increase j + Seq(Row(1, 1)), // select i, j + Seq.empty[Row], // increase j + Seq.empty[Row], // increase i + Seq.empty[Row], // drop j + Seq.empty[Row] // 
drop i + ) + verifySqlScriptResult(commands, expected) + } + + test("repeat with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |REPEAT + | SELECT 42; + | INSERT INTO t VALUES (1, 'a', 1.0); + |UNTIL (SELECT COUNT(*) >= 2 FROM t) + |END REPEAT; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq(Row(42)), // select + Seq.empty[Row], // insert + Seq(Row(42)), // select + Seq.empty[Row] // insert + ) + verifySqlScriptResult(commands, expected) + } + } + + test("repeat with non boolean condition - constant") { + val commands = + """ + |BEGIN + | DECLARE i = 0; + | REPEAT + | SELECT i; + | SET VAR i = i + 1; + | UNTIL + | 1 + | END REPEAT; + |END + |""".stripMargin + + checkError( + exception = intercept[SqlScriptingException] ( + runSqlScript(commands) + ), + condition = "INVALID_BOOLEAN_STATEMENT", + parameters = Map("invalidStatement" -> "1") + ) + } + + test("repeat with empty subquery condition") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a BOOLEAN) USING parquet; + | REPEAT + | SELECT 1; + | UNTIL + | (SELECT * FROM t) + | END REPEAT; + |END + |""".stripMargin + + checkError( + exception = intercept[SqlScriptingException] ( + runSqlScript(commands) + ), + condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW", + parameters = Map("invalidStatement" -> "(SELECT * FROM T)") + ) + } + } + + test("repeat with too many rows in subquery condition") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a BOOLEAN) USING parquet; + | INSERT INTO t VALUES (true); + | INSERT INTO t VALUES (true); + | REPEAT + | SELECT 1; + | UNTIL + | (SELECT * FROM t) + | END REPEAT; + |END + |""".stripMargin + + checkError( + exception = intercept[SparkException] ( + runSqlScript(commands) + ), + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS", + parameters = Map.empty, + context = ExpectedContext(fragment = "(SELECT * FROM t)", start = 141, stop = 157) + ) + } + } + + test("leave compound block") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | LEAVE lbl; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)) // select + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | SELECT 1; + | LEAVE lbl; + | END WHILE; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)) // select + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave repeat loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: REPEAT + | SELECT 1; + | LEAVE lbl; + | UNTIL 1 = 2 + | END REPEAT; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)) // select 1 + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate compound block - should fail") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | ITERATE lbl; + |END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + runSqlScript(sqlScriptText) + }, + condition = "INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND", + parameters = Map("labelName" -> "LBL")) + } + + test("iterate while loop") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: WHILE x < 2 DO + | SET x = x + 1; + | ITERATE lbl; + | SET x = x + 2; + | END WHILE; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare + Seq.empty[Row], // set x = 0 + Seq.empty[Row], // set x = 1 + Seq.empty[Row], // set x = 2 + Seq(Row(2)), // select + 
Seq.empty[Row] // drop + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate repeat loop") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: REPEAT + | SET x = x + 1; + | ITERATE lbl; + | SET x = x + 2; + | UNTIL x > 1 + | END REPEAT; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare + Seq.empty[Row], // set x = 0 + Seq.empty[Row], // set x = 1 + Seq.empty[Row], // set x = 2 + Seq(Row(2)), // select x + Seq.empty[Row] // drop + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave with wrong label - should fail") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | LEAVE randomlbl; + |END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + runSqlScript(sqlScriptText) + }, + condition = "INVALID_LABEL_USAGE.DOES_NOT_EXIST", + parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "LEAVE")) + } + + test("iterate with wrong label - should fail") { + val sqlScriptText = + """ + |lbl: BEGIN + | SELECT 1; + | ITERATE randomlbl; + |END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + runSqlScript(sqlScriptText) + }, + condition = "INVALID_LABEL_USAGE.DOES_NOT_EXIST", + parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "ITERATE")) + } + + test("leave outer loop from nested repeat loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: REPEAT + | lbl2: REPEAT + | SELECT 1; + | LEAVE lbl; + | UNTIL 1 = 2 + | END REPEAT; + | UNTIL 1 = 2 + | END REPEAT; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)) // select 1 + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave outer loop from nested while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | lbl2: WHILE 2 = 2 DO + | SELECT 1; + | LEAVE lbl; + | END WHILE; + | END WHILE; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)) // select + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate outer loop from nested while loop") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: WHILE x < 2 DO + | SET x = x + 1; + | lbl2: WHILE 2 = 2 DO + | SELECT 1; + | ITERATE lbl; + | END WHILE; + | END WHILE; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare + Seq.empty[Row], // set x = 0 + Seq.empty[Row], // set x = 1 + Seq(Row(1)), // select 1 + Seq.empty[Row], // set x = 2 + Seq(Row(1)), // select 1 + Seq(Row(2)), // select x + Seq.empty[Row] // drop + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("nested compounds in loop - leave in inner compound") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: WHILE x < 2 DO + | SET x = x + 1; + | BEGIN + | SELECT 1; + | lbl2: BEGIN + | SELECT 2; + | LEAVE lbl2; + | SELECT 3; + | END; + | END; + | END WHILE; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare + Seq.empty[Row], // set x = 0 + Seq.empty[Row], // set x = 1 + Seq(Row(1)), // select 1 + Seq(Row(2)), // select 2 + Seq.empty[Row], // set x = 2 + Seq(Row(1)), // select 1 + Seq(Row(2)), // select 2 + Seq(Row(2)), // select x + Seq.empty[Row] // drop + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate outer loop from nested repeat loop") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: REPEAT + | SET x = x + 1; + | lbl2: REPEAT + | SELECT 1; + | ITERATE lbl; + | UNTIL 1 = 2 + | END REPEAT; + | 
UNTIL x > 1 + | END REPEAT; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare + Seq.empty[Row], // set x = 0 + Seq.empty[Row], // set x = 1 + Seq(Row(1)), // select 1 + Seq.empty[Row], // set x = 2 + Seq(Row(1)), // select 1 + Seq(Row(2)), // select x + Seq.empty[Row] // drop + ) + verifySqlScriptResult(sqlScriptText, expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 8b11e0c69fa70..24732223c6698 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -54,7 +54,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti protected override def beforeAll(): Unit = { super.beforeAll() - spark.conf.set(SQLConf.LEGACY_BUCKETED_TABLE_SCAN_OUTPUT_ORDERING, true) + spark.conf.set(SQLConf.LEGACY_BUCKETED_TABLE_SCAN_OUTPUT_ORDERING.key, true) } protected override def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 4f1b7d363a124..b473716b33fca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -91,7 +91,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { df.write.sortBy("j").saveAsTable("tt") }, - errorClass = "SORT_BY_WITHOUT_BUCKETING", + condition = "SORT_BY_WITHOUT_BUCKETING", parameters = Map.empty) } @@ -106,7 +106,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { df.write.bucketBy(2, "i").parquet("/tmp/path") }, - errorClass = "_LEGACY_ERROR_TEMP_1312", + condition = "_LEGACY_ERROR_TEMP_1312", parameters = Map("operation" -> "save")) } @@ -116,7 +116,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { df.write.bucketBy(2, "i").sortBy("i").parquet("/tmp/path") }, - errorClass = "_LEGACY_ERROR_TEMP_1313", + condition = "_LEGACY_ERROR_TEMP_1313", parameters = Map("operation" -> "save")) } @@ -126,7 +126,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { df.write.bucketBy(2, "i").insertInto("tt") }, - errorClass = "_LEGACY_ERROR_TEMP_1312", + condition = "_LEGACY_ERROR_TEMP_1312", parameters = Map("operation" -> "insertInto")) } @@ -136,7 +136,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { exception = intercept[AnalysisException] { df.write.bucketBy(2, "i").sortBy("i").insertInto("tt") }, - errorClass = "_LEGACY_ERROR_TEMP_1313", + condition = "_LEGACY_ERROR_TEMP_1313", parameters = Map("operation" -> "insertInto")) } @@ -252,7 +252,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { .bucketBy(8, "j", "k") .sortBy("k") .saveAsTable("bucketed_table")), - errorClass = "_LEGACY_ERROR_TEMP_1166", + condition = "_LEGACY_ERROR_TEMP_1166", parameters = Map("bucketCol" -> "j", "normalizedPartCols" -> "i, j")) checkError( @@ -261,7 +261,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { .bucketBy(8, "k") .sortBy("i") .saveAsTable("bucketed_table")), - errorClass = 
"_LEGACY_ERROR_TEMP_1167", + condition = "_LEGACY_ERROR_TEMP_1167", parameters = Map("sortCol" -> "i", "normalizedPartCols" -> "i, j")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 6f897a9c0b7b0..95c2fcbd7b5d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -164,7 +164,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSparkSession { exception = intercept[ParseException] { sql(sqlText) }, - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map( "message" -> "CREATE TEMPORARY TABLE ... AS ..., use CREATE TEMPORARY VIEW instead"), context = ExpectedContext( @@ -291,7 +291,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSparkSession { exception = intercept[ParseException] { sql(sqlText) }, - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map( "message" -> "Schema may not be specified in a Create Table As Select (CTAS) statement"), context = ExpectedContext( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala index b6fb83fa5b876..43dfed277cbe7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -30,7 +30,7 @@ class DDLSourceLoadSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { spark.read.format("Fluet da Bomb").load() }, - errorClass = "_LEGACY_ERROR_TEMP_1141", + condition = "_LEGACY_ERROR_TEMP_1141", parameters = Map( "provider" -> "Fluet da Bomb", "sourceNames" -> ("org.apache.spark.sql.sources.FakeSourceOne, " + @@ -49,7 +49,7 @@ class DDLSourceLoadSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { spark.read.format("Fake external source").load() }, - errorClass = "_LEGACY_ERROR_TEMP_1141", + condition = "_LEGACY_ERROR_TEMP_1141", parameters = Map( "provider" -> "Fake external source", "sourceNames" -> ("org.apache.fakesource.FakeExternalSourceOne, " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index e3e385e9d1810..41447d8af5740 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -117,7 +117,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("INSERT INTO TABLE t1 SELECT a FROM t2") }, - errorClass = "UNSUPPORTED_INSERT.NOT_ALLOWED", + condition = "UNSUPPORTED_INSERT.NOT_ALLOWED", parameters = Map("relationId" -> "`SimpleScan(1,10)`") ) } @@ -131,7 +131,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("INSERT INTO TABLE t1 SELECT a FROM t1") }, - errorClass = "UNSUPPORTED_INSERT.RDD_BASED", + condition = "UNSUPPORTED_INSERT.RDD_BASED", parameters = Map.empty ) } @@ -151,7 +151,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] 
{ sql("INSERT INTO TABLE t1 SELECT * FROM t1") }, - errorClass = "UNSUPPORTED_INSERT.READ_FROM", + condition = "UNSUPPORTED_INSERT.READ_FROM", parameters = Map("relationId" -> "`SimpleScan(1,10)`") ) } @@ -293,7 +293,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jsonTable") }, - errorClass = "UNSUPPORTED_OVERWRITE.PATH", + condition = "UNSUPPORTED_OVERWRITE.PATH", parameters = Map("path" -> ".*")) } @@ -338,7 +338,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { |SELECT i + 1, part2 FROM insertTable """.stripMargin) }, - errorClass = "UNSUPPORTED_OVERWRITE.TABLE", + condition = "UNSUPPORTED_OVERWRITE.TABLE", parameters = Map("table" -> "`spark_catalog`.`default`.`inserttable`")) } } @@ -418,7 +418,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("INSERT OVERWRITE TABLE oneToTen SELECT CAST(a AS INT) FROM jt") }, - errorClass = "UNSUPPORTED_INSERT.NOT_ALLOWED", + condition = "UNSUPPORTED_INSERT.NOT_ALLOWED", parameters = Map("relationId" -> "`SimpleScan(1,10)`")) spark.catalog.dropTempView("oneToTen") @@ -527,7 +527,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { |SELECT 1, 2 """.stripMargin) }, - errorClass = "NOT_SUPPORTED_COMMAND_WITHOUT_HIVE_SUPPORT", + condition = "NOT_SUPPORTED_COMMAND_WITHOUT_HIVE_SUPPORT", parameters = Map("cmd" -> "INSERT OVERWRITE DIRECTORY with the Hive format") ) } @@ -548,7 +548,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[SparkException] { spark.sql(v1) }, - errorClass = "_LEGACY_ERROR_TEMP_2233", + condition = "_LEGACY_ERROR_TEMP_2233", parameters = Map( "providingClass" -> ("class org.apache.spark.sql.execution.datasources." 
+ "jdbc.JdbcRelationProvider")) @@ -658,7 +658,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t select 1L, 2") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`i`", @@ -670,7 +670,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t select 1, 2.0") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`d`", @@ -682,7 +682,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t select 1, 2.0D, 3") }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "tableColumns" -> "`i`, `d`", @@ -705,7 +705,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t values('a', 'b')") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`i`", @@ -716,7 +716,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t values(now(), now())") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`i`", @@ -727,7 +727,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t values(true, false)") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`i`", @@ -775,7 +775,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[SparkArithmeticException] { sql(s"insert into t values($outOfRangeValue1)") }, - errorClass = "CAST_OVERFLOW_IN_TABLE_INSERT", + condition = "CAST_OVERFLOW_IN_TABLE_INSERT", parameters = Map( "sourceType" -> "\"BIGINT\"", "targetType" -> "\"INT\"", @@ -786,7 +786,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[SparkArithmeticException] { sql(s"insert into t values($outOfRangeValue2)") }, - errorClass = "CAST_OVERFLOW_IN_TABLE_INSERT", + condition = "CAST_OVERFLOW_IN_TABLE_INSERT", parameters = Map( "sourceType" -> "\"BIGINT\"", "targetType" -> "\"INT\"", @@ -806,7 +806,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[SparkArithmeticException] { sql(s"insert into t values(${outOfRangeValue1}D)") }, - errorClass = "CAST_OVERFLOW_IN_TABLE_INSERT", + condition = "CAST_OVERFLOW_IN_TABLE_INSERT", parameters = Map( "sourceType" -> "\"DOUBLE\"", "targetType" -> "\"BIGINT\"", @@ -817,7 +817,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = 
intercept[SparkArithmeticException] { sql(s"insert into t values(${outOfRangeValue2}D)") }, - errorClass = "CAST_OVERFLOW_IN_TABLE_INSERT", + condition = "CAST_OVERFLOW_IN_TABLE_INSERT", parameters = Map( "sourceType" -> "\"DOUBLE\"", "targetType" -> "\"BIGINT\"", @@ -836,7 +836,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[SparkArithmeticException] { sql(s"insert into t values($outOfRangeValue)") }, - errorClass = "CAST_OVERFLOW_IN_TABLE_INSERT", + condition = "CAST_OVERFLOW_IN_TABLE_INSERT", parameters = Map( "sourceType" -> "\"DECIMAL(5,2)\"", "targetType" -> "\"DECIMAL(3,2)\"", @@ -854,7 +854,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("INSERT INTO t VALUES (TIMESTAMP('2010-09-02 14:10:10'), 1)") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`i`", @@ -869,7 +869,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("INSERT INTO t VALUES (date('2010-09-02'), 1)") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`i`", @@ -884,7 +884,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("INSERT INTO t VALUES (TIMESTAMP('2010-09-02 14:10:10'), true)") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`b`", @@ -899,7 +899,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("INSERT INTO t VALUES (date('2010-09-02'), true)") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`b`", @@ -971,7 +971,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt") }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> "`unknown`", "tableColumns" -> "`a`, `b`", @@ -994,7 +994,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt") }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> "`unknown`", "tableColumns" -> "`a`, `b`", @@ -1170,7 +1170,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("create table t(i boolean, s bigint default badvalue) using parquet") }, - errorClass = "INVALID_DEFAULT_VALUE.NOT_CONSTANT", + condition = "INVALID_DEFAULT_VALUE.NOT_CONSTANT", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`s`", @@ -1186,7 +1186,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession 
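Most of the changes in these suites are a mechanical rename of checkError's named argument from errorClass to condition; every other part of each assertion is unchanged. A sketch of the updated call shape, reusing an assertion from InsertSuite above (intercept, sql and checkError come from the surrounding suite):

  checkError(
    exception = intercept[AnalysisException] {
      sql("INSERT INTO TABLE t1 SELECT a FROM t2")
    },
    condition = "UNSUPPORTED_INSERT.NOT_ALLOWED",   // previously: errorClass = ...
    parameters = Map("relationId" -> "`SimpleScan(1,10)`"))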
{ }, // V1 command still use the fake Analyzer which can't resolve session variables and we // can only report UNRESOLVED_EXPRESSION error. - errorClass = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`s`", @@ -1199,7 +1199,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { }, // V2 command uses the actual analyzer and can resolve session variables. We can report // a more accurate NOT_CONSTANT error. - errorClass = "INVALID_DEFAULT_VALUE.NOT_CONSTANT", + condition = "INVALID_DEFAULT_VALUE.NOT_CONSTANT", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`j`", @@ -1216,7 +1216,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("create table t(i boolean, s bigint default (select min(x) from badtable)) " + "using parquet") }, - errorClass = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`s`", @@ -1230,7 +1230,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("create table t(i boolean, s bigint default (select min(x) from other)) " + "using parquet") }, - errorClass = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`s`", @@ -1243,7 +1243,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("create table t(i boolean default (select false as alias), s bigint) using parquet") }, - errorClass = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`i`", @@ -1256,7 +1256,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t values(false, default + 1)") }, - errorClass = "DEFAULT_PLACEMENT_INVALID", + condition = "DEFAULT_PLACEMENT_INVALID", parameters = Map.empty ) } @@ -1267,7 +1267,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t select false, default + 1") }, - errorClass = "DEFAULT_PLACEMENT_INVALID", + condition = "DEFAULT_PLACEMENT_INVALID", parameters = Map.empty ) } @@ -1277,7 +1277,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t values(false, default)") }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`t`"), context = ExpectedContext("t", 12, 12) ) @@ -1288,7 +1288,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("create table t(i boolean, s bigint default false) using parquet") }, - errorClass = "INVALID_DEFAULT_VALUE.DATA_TYPE", + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`s`", @@ -1306,7 +1306,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("insert into t select t1.id, t2.id, t1.val, t2.val, t1.val * t2.val " + "from num_data t1, num_data t2") }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", 
parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "tableColumns" -> "`id1`, `int2`, `result`", @@ -1319,7 +1319,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[ParseException] { sql("create table t(i boolean, s bigint default 42L) using parquet") }, - errorClass = "UNSUPPORTED_DEFAULT_VALUE.WITH_SUGGESTION", + condition = "UNSUPPORTED_DEFAULT_VALUE.WITH_SUGGESTION", parameters = Map.empty, context = ExpectedContext("s bigint default 42L", 26, 45) ) @@ -1333,7 +1333,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[ParseException] { sql("insert into t partition(i=default) values(5, default)") }, - errorClass = "REF_DEFAULT_VALUE_IS_NOT_ALLOWED_IN_PARTITION", + condition = "REF_DEFAULT_VALUE_IS_NOT_ALLOWED_IN_PARTITION", parameters = Map.empty, context = ExpectedContext( fragment = "partition(i=default)", @@ -1349,7 +1349,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t values(true)") }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "tableColumns" -> "`i`, `s`", @@ -1423,7 +1423,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t (i, q) select true from (select 1)") }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "tableColumns" -> "`i`, `q`", @@ -1439,7 +1439,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t (i) values (true)") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`s`")) @@ -1450,7 +1450,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t (i) values (default)") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`s`")) @@ -1461,7 +1461,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t (s) values (default)") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`i`")) @@ -1472,7 +1472,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t partition(i='true') (s) values(5)") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`q`")) @@ -1483,7 +1483,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t partition(i='false') (q) select 43") }, - errorClass = 
"INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`s`")) @@ -1494,7 +1494,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t partition(i='false') (q) select default") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`s`")) @@ -1508,7 +1508,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](sql("insert into t (I) select true from (select 1)")), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map("objectName" -> "`I`", "proposal" -> "`i`, `s`"), context = ExpectedContext( @@ -1640,7 +1640,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t add column s bigint default badvalue") }, - errorClass = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", parameters = Map( "statement" -> "ALTER TABLE ADD COLUMNS", "colName" -> "`s`", @@ -1653,7 +1653,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t add column s bigint default (select min(x) from badtable)") }, - errorClass = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", parameters = Map( "statement" -> "ALTER TABLE ADD COLUMNS", "colName" -> "`s`", @@ -1667,7 +1667,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t add column s bigint default (select min(x) from other)") }, - errorClass = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", parameters = Map( "statement" -> "ALTER TABLE ADD COLUMNS", "colName" -> "`s`", @@ -1680,7 +1680,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t add column s bigint default false") }, - errorClass = "INVALID_DEFAULT_VALUE.DATA_TYPE", + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", parameters = Map( "statement" -> "ALTER TABLE ADD COLUMNS", "colName" -> "`s`", @@ -1696,7 +1696,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[ParseException] { sql("alter table t add column s bigint default 42L") }, - errorClass = "UNSUPPORTED_DEFAULT_VALUE.WITH_SUGGESTION", + condition = "UNSUPPORTED_DEFAULT_VALUE.WITH_SUGGESTION", parameters = Map.empty, context = ExpectedContext( fragment = "s bigint default 42L", @@ -1740,7 +1740,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t alter column s set default badvalue") }, - errorClass = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", parameters = Map( "statement" -> "ALTER TABLE ALTER COLUMN", "colName" -> "`s`", @@ -1750,7 +1750,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t alter column 
s set default (select min(x) from badtable)") }, - errorClass = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", parameters = Map( "statement" -> "ALTER TABLE ALTER COLUMN", "colName" -> "`s`", @@ -1761,7 +1761,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t alter column s set default (select 42 as alias)") }, - errorClass = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", parameters = Map( "statement" -> "ALTER TABLE ALTER COLUMN", "colName" -> "`s`", @@ -1771,7 +1771,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t alter column s set default false") }, - errorClass = "INVALID_DEFAULT_VALUE.DATA_TYPE", + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", parameters = Map( "statement" -> "ALTER TABLE ALTER COLUMN", "colName" -> "`s`", @@ -1785,7 +1785,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[ParseException] { sql(sqlText) }, - errorClass = "UNSUPPORTED_DEFAULT_VALUE.WITH_SUGGESTION", + condition = "UNSUPPORTED_DEFAULT_VALUE.WITH_SUGGESTION", parameters = Map.empty, context = ExpectedContext( fragment = sqlText, @@ -1800,7 +1800,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t alter column i set default false") }, - errorClass = "CANNOT_ALTER_PARTITION_COLUMN", + condition = "CANNOT_ALTER_PARTITION_COLUMN", parameters = Map("tableName" -> "`spark_catalog`.`default`.`t`", "columnName" -> "`i`") ) } @@ -1964,7 +1964,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq("xyz").toDF().select("value", "default").write.insertInto("t") }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`default`", "proposal" -> "`value`"), context = ExpectedContext(fragment = "select", callSitePattern = getCurrentClassCallSitePattern)) @@ -1998,7 +1998,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"create table t(a string default 'abc') using parquet") }, - errorClass = "_LEGACY_ERROR_TEMP_1345", + condition = "DEFAULT_UNSUPPORTED", parameters = Map("statementType" -> "CREATE TABLE", "dataSource" -> "parquet")) withTable("t") { sql(s"create table t(a string, b int) using parquet") @@ -2006,7 +2006,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t add column s bigint default 42") }, - errorClass = "_LEGACY_ERROR_TEMP_1345", + condition = "DEFAULT_UNSUPPORTED", parameters = Map( "statementType" -> "ALTER TABLE ADD COLUMNS", "dataSource" -> "parquet")) @@ -2065,7 +2065,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t add column s array default array('abc', 'def')") }, - errorClass = "INVALID_DEFAULT_VALUE.DATA_TYPE", + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", parameters = Map( "statement" -> "ALTER TABLE ADD COLUMNS", "colName" -> "`s`", @@ -2128,7 +2128,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t add column s struct default 
struct(42, 56)") }, - errorClass = "INVALID_DEFAULT_VALUE.DATA_TYPE", + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", parameters = Map( "statement" -> "ALTER TABLE ADD COLUMNS", "colName" -> "`s`", @@ -2248,7 +2248,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t add column s map default map(42, 56)") }, - errorClass = "INVALID_DEFAULT_VALUE.DATA_TYPE", + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", parameters = Map( "statement" -> "ALTER TABLE ADD COLUMNS", "colName" -> "`s`", @@ -2264,7 +2264,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("create table t(a string default (select 'abc')) using parquet") }, - errorClass = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`a`", @@ -2273,7 +2273,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("create table t(a string default exists(select 42 where true)) using parquet") }, - errorClass = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`a`", @@ -2282,7 +2282,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("create table t(a string default 1 in (select 1 union all select 2)) using parquet") }, - errorClass = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + condition = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", parameters = Map( "statement" -> "CREATE TABLE", "colName" -> "`a`", @@ -2314,7 +2314,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { // provider is now in the denylist. 
sql(s"alter table t1 add column (b string default 'abc')") }, - errorClass = "_LEGACY_ERROR_TEMP_1346", + condition = "ADD_DEFAULT_UNSUPPORTED", parameters = Map( "statementType" -> "ALTER TABLE ADD COLUMNS", "dataSource" -> provider)) @@ -2389,7 +2389,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { } else { checkError( exception = err, - errorClass = "TASK_WRITE_FAILED", + condition = "TASK_WRITE_FAILED", parameters = Map("path" -> s".*$tableName"), matchPVals = true ) @@ -2419,7 +2419,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[ParseException] { sql("insert overwrite local directory 'hdfs:/abcd' using parquet select 1") }, - errorClass = "LOCAL_MUST_WITH_SCHEMA_FILE", + condition = "LOCAL_MUST_WITH_SCHEMA_FILE", parameters = Map("actualSchema" -> "hdfs"), context = ExpectedContext( fragment = "insert overwrite local directory 'hdfs:/abcd' using parquet", @@ -2439,7 +2439,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("INSERT INTO TABLE insertTable PARTITION(part1=1, part2='') SELECT 1") }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> ("The spec ([part1=Some(1), part2=Some()]) " + "contains an empty partition column value")) @@ -2448,7 +2448,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("INSERT INTO TABLE insertTable PARTITION(part1='', part2) SELECT 1 ,'' AS part2") }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> ("The spec ([part1=Some(), part2=None]) " + "contains an empty partition column value")) @@ -2475,7 +2475,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { |) """.stripMargin) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`c3`", @@ -2705,7 +2705,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { spark.table(tableName).write.mode(SaveMode.Overwrite).saveAsTable(tableName) }, - errorClass = "UNSUPPORTED_OVERWRITE.TABLE", + condition = "UNSUPPORTED_OVERWRITE.TABLE", parameters = Map("table" -> s"`spark_catalog`.`default`.`$tableName`") ) } @@ -2726,7 +2726,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(insertDirSql) }, - errorClass = "UNSUPPORTED_OVERWRITE.PATH", + condition = "UNSUPPORTED_OVERWRITE.PATH", parameters = Map("path" -> ("file:" + path))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index 730b63850d99a..f3849fe34ec29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -168,7 +168,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq((3, 2)).toDF("a", "b").write.partitionBy("b", "b").csv(f.getAbsolutePath) }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`b`")) } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 6067efc1d1c1c..ee40d70c88291 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -84,7 +84,7 @@ class ResolvedDataSourceSuite extends SharedSparkSession { exception = intercept[AnalysisException] { getProvidingClass(provider) }, - errorClass = "_LEGACY_ERROR_TEMP_1139", + condition = "_LEGACY_ERROR_TEMP_1139", parameters = Map("provider" -> provider) ) } @@ -95,7 +95,7 @@ class ResolvedDataSourceSuite extends SharedSparkSession { exception = intercept[AnalysisException] { getProvidingClass("kafka") }, - errorClass = "_LEGACY_ERROR_TEMP_1140", + condition = "_LEGACY_ERROR_TEMP_1140", parameters = Map("provider" -> "kafka") ) } @@ -106,7 +106,7 @@ class ResolvedDataSourceSuite extends SharedSparkSession { } checkError( exception = error, - errorClass = "DATA_SOURCE_NOT_FOUND", + condition = "DATA_SOURCE_NOT_FOUND", parameters = Map("provider" -> "asfdwefasdfasdf") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index d1fe601838cb6..e27ec32e287e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -50,7 +50,7 @@ abstract class FileStreamSinkSuite extends StreamTest { override def beforeAll(): Unit = { super.beforeAll() - spark.conf.set(SQLConf.ORC_IMPLEMENTATION, "native") + spark.conf.set(SQLConf.ORC_IMPLEMENTATION.key, "native") } override def afterAll(): Unit = { @@ -280,7 +280,7 @@ abstract class FileStreamSinkSuite extends StreamTest { exception = intercept[AnalysisException] { df.writeStream.format("parquet").outputMode(mode).start(dir.getCanonicalPath) }, - errorClass = "STREAMING_OUTPUT_MODE.UNSUPPORTED_DATASOURCE", + condition = "STREAMING_OUTPUT_MODE.UNSUPPORTED_DATASOURCE", sqlState = "42KDE", parameters = Map("className" -> "parquet", "outputMode" -> mode)) } @@ -378,7 +378,7 @@ abstract class FileStreamSinkSuite extends StreamTest { exception = intercept[AnalysisException] { spark.read.schema(s"$c0 INT, $c1 INT").json(outputDir).as[(Int, Int)] }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c1.toLowerCase(Locale.ROOT)}`")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 634311b669a85..773be0cc08e3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -262,7 +262,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { override def beforeAll(): Unit = { super.beforeAll() - spark.conf.set(SQLConf.ORC_IMPLEMENTATION, "native") + spark.conf.set(SQLConf.ORC_IMPLEMENTATION.key, "native") } override def afterAll(): Unit = { @@ -419,7 +419,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { createFileStreamSourceAndGetSchema( format = Some("json"), path = Some(src.getCanonicalPath), schema = None) }, - errorClass = "UNABLE_TO_INFER_SCHEMA", + condition = 
"UNABLE_TO_INFER_SCHEMA", parameters = Map("format" -> "JSON") ) } @@ -1504,7 +1504,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { // This is to avoid running a spark job to list of files in parallel // by the InMemoryFileIndex. - spark.conf.set(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD, numFiles * 2) + spark.conf.set(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key, numFiles * 2) withTempDirs { case (root, tmp) => val src = new File(root, "a=1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 45a80a210fcee..f7ff39622ed40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -23,7 +23,6 @@ import java.sql.Timestamp import org.apache.commons.io.FileUtils import org.scalatest.exceptions.TestFailedException -import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction import org.apache.spark.sql.{DataFrame, Encoder} import org.apache.spark.sql.catalyst.InternalRow @@ -635,6 +634,72 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { ) } + testWithAllStateVersions("[SPARK-49474] flatMapGroupsWithState - user NPE is classified") { + // Throws NPE + val stateFunc = (_: String, _: Iterator[String], _: GroupState[RunningCount]) => { + throw new NullPointerException() + // Need to return an iterator for compilation to get types + Iterator(1) + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) + + testStream(result, Update)( + AddData(inputData, "a"), + ExpectFailure[FlatMapGroupsWithStateUserFuncException]() + ) + } + + testWithAllStateVersions( + "[SPARK-49474] flatMapGroupsWithState - null user iterator error is classified") { + // Returns null, will throw NPE when method is called on it + val stateFunc = (_: String, _: Iterator[String], _: GroupState[RunningCount]) => { + null.asInstanceOf[Iterator[Int]] + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) + + testStream(result, Update)( + AddData(inputData, "a"), + ExpectFailure[FlatMapGroupsWithStateUserFuncException]() + ) + } + + testWithAllStateVersions( + "[SPARK-49474] flatMapGroupsWithState - NPE from user iterator is classified") { + // Returns iterator that throws NPE when next is called + val stateFunc = (_: String, _: Iterator[String], _: GroupState[RunningCount]) => { + new Iterator[Int] { + override def hasNext: Boolean = { + true + } + + override def next(): Int = { + throw new NullPointerException() + } + } + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) + + testStream(result, Update)( + AddData(inputData, "a"), + ExpectFailure[FlatMapGroupsWithStateUserFuncException]() + ) + } + test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) @@ -816,7 +881,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { 
CheckNewAnswer(("a", 2L)), setFailInTask(true), AddData(inputData, "a"), - ExpectFailure[SparkException](), // task should fail but should not increment count + // task should fail but should not increment count + ExpectFailure[FlatMapGroupsWithStateUserFuncException](), setFailInTask(false), StartStream(), CheckNewAnswer(("a", 3L)) // task should not fail, and should show correct count @@ -1097,7 +1163,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { func: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int], timeoutType: GroupStateTimeout = GroupStateTimeout.NoTimeout, batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = { - val stateFormatVersion = spark.conf.get(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + val stateFormatVersion = sqlConf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) val emptyRdd = spark.sparkContext.emptyRDD[InternalRow] MemoryStream[Int] .toDS() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/GroupStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/GroupStateSuite.scala index 050c1a2d7d978..80c87d3297b01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/GroupStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/GroupStateSuite.scala @@ -302,13 +302,13 @@ class GroupStateSuite extends SparkFunSuite { TestGroupState.create[Int]( Optional.of(5), NoTimeout, 100L, Optional.empty[Long], hasTimedOut = true) }, - errorClass = "_LEGACY_ERROR_TEMP_3168", + condition = "_LEGACY_ERROR_TEMP_3168", parameters = Map.empty) checkError( exception = intercept[SparkUnsupportedOperationException] { GroupStateImpl.createForStreaming[Int](Some(5), 100L, NO_TIMESTAMP, NoTimeout, true, false) }, - errorClass = "_LEGACY_ERROR_TEMP_3168", + condition = "_LEGACY_ERROR_TEMP_3168", parameters = Map.empty) } @@ -349,7 +349,7 @@ class GroupStateSuite extends SparkFunSuite { def assertWrongTimeoutError(test: => Unit): Unit = { checkError( exception = intercept[SparkUnsupportedOperationException] { test }, - errorClass = "_LEGACY_ERROR_TEMP_2204", + condition = "_LEGACY_ERROR_TEMP_2204", parameters = Map.empty) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 7ab45e25799bc..68436c4e355b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -542,10 +542,7 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) additionalConfs.foreach(pair => { - val value = - if (sparkSession.conf.contains(pair._1)) { - Some(sparkSession.conf.get(pair._1)) - } else None + val value = sparkSession.conf.getOption(pair._1) resetConfValues(pair._1) = value sparkSession.conf.set(pair._1, pair._2) }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index 854893b1f033e..ab9df9a1e5a6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -521,7 +521,7 @@ class StreamingDeduplicationSuite extends StateStoreMetricsTest { // verify that the key schema not 
compatible error is thrown checkError( ex.getCause.asInstanceOf[SparkUnsupportedOperationException], - errorClass = "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE", + condition = "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE", parameters = Map("storedKeySchema" -> ".*", "newKeySchema" -> ".*"), matchPVals = true @@ -567,7 +567,7 @@ class StreamingDeduplicationSuite extends StateStoreMetricsTest { checkError( ex.getCause.asInstanceOf[SparkUnsupportedOperationException], - errorClass = "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY", + condition = "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY", parameters = Map( "schema" -> ".+\"str\":\"spark.UTF8_LCASE\".+" ), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 5e9bdad8fd825..a733d54d275d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -752,7 +752,7 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { checkError( ex.getCause.asInstanceOf[SparkUnsupportedOperationException], - errorClass = "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE", + condition = "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE", parameters = Map("storedKeySchema" -> ".*", "newKeySchema" -> ".*"), matchPVals = true @@ -822,7 +822,7 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { checkError( ex.getCause.asInstanceOf[SparkUnsupportedOperationException], - errorClass = "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE", + condition = "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE", parameters = Map("storedValueSchema" -> ".*", "newValueSchema" -> ".*"), matchPVals = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala index 782badaef924f..f651bfb7f3c72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala @@ -21,7 +21,7 @@ import java.sql.Timestamp import org.apache.spark.sql.Row import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.functions.{expr, lit, window} +import org.apache.spark.sql.functions.{count, expr, lit, timestamp_seconds, window} import org.apache.spark.sql.internal.SQLConf /** @@ -524,4 +524,66 @@ class StreamingQueryOptimizationCorrectnessSuite extends StreamTest { doTest(numExpectedStatefulOperatorsForOneEmptySource = 1) } } + + test("SPARK-49699: observe node is not pruned out from PruneFilters") { + val input1 = MemoryStream[Int] + val df = input1.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .observe("observation", count(lit(1)).as("rows")) + // Enforce PruneFilters to come into play and prune subtree. We could do the same + // with the reproducer of SPARK-48267, but let's just be simpler. 
+ .filter(expr("false")) + + testStream(df)( + AddData(input1, 1, 2, 3), + CheckNewAnswer(), + Execute { qe => + val observeRow = qe.lastExecution.observedMetrics.get("observation") + assert(observeRow.get.getAs[Long]("rows") == 3L) + } + ) + } + + test("SPARK-49699: watermark node is not pruned out from PruneFilters") { + // NOTE: The test actually passes without SPARK-49699, because of the trickiness of + // filter pushdown and PruneFilters. Unlike observe node, the `false` filter is pushed down + // below to watermark node, hence PruneFilters rule does not prune out watermark node even + // before SPARK-49699. Propagate empty relation does not also propagate emptiness into + // watermark node, so the node is retained. The test is added for preventing regression. + + val input1 = MemoryStream[Int] + val df = input1.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "0 second") + // Enforce PruneFilter to come into play and prune subtree. We could do the same + // with the reproducer of SPARK-48267, but let's just be simpler. + .filter(expr("false")) + + testStream(df)( + AddData(input1, 1, 2, 3), + CheckNewAnswer(), + Execute { qe => + // If the watermark node is pruned out, this would be null. + assert(qe.lastProgress.eventTime.get("watermark") != null) + } + ) + } + + test("SPARK-49699: stateful operator node is not pruned out from PruneFilters") { + val input1 = MemoryStream[Int] + val df = input1.toDF() + .groupBy("value") + .count() + // Enforce PruneFilter to come into play and prune subtree. We could do the same + // with the reproducer of SPARK-48267, but let's just be simpler. + .filter(expr("false")) + + testStream(df, OutputMode.Complete())( + AddData(input1, 1, 2, 3), + CheckNewAnswer(), + Execute { qe => + assert(qe.lastProgress.stateOperators.length == 1) + } + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 2767f2dd46b2e..8471995cb1e50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -1422,7 +1422,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } checkError( ex.getCause.asInstanceOf[SparkUnsupportedOperationException], - errorClass = "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY", + condition = "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY", parameters = Map( "schema" -> ".+\"c1\":\"spark.UTF8_LCASE\".+" ), @@ -1457,7 +1457,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi .option("checkpointLocation", checkpointDir.getCanonicalPath) .start(outputDir.getCanonicalPath) }, - errorClass = "STREAMING_OUTPUT_MODE.UNSUPPORTED_OPERATION", + condition = "STREAMING_OUTPUT_MODE.UNSUPPORTED_OPERATION", sqlState = "42KDE", parameters = Map( "outputMode" -> "append", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala index bf4e3f0a4e4aa..ec3c145af686c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala @@ -607,7 +607,7 @@ class StreamingSessionWindowSuite extends StreamTest CheckAnswer() // this is just 
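The three SPARK-49699 tests above share one query shape: a statically false filter sits on top of an observe, watermark, or stateful aggregation node, so PruneFilters would otherwise collapse the subtree into an empty relation and drop that node with it. A minimal sketch of the observe variant (written for a StreamTest suite with the usual implicits in scope; illustrative only):

  import org.apache.spark.sql.functions.{count, expr, lit, timestamp_seconds}

  val input = MemoryStream[Int]
  val df = input.toDF()
    .withColumn("eventTime", timestamp_seconds($"value"))
    .observe("observation", count(lit(1)).as("rows"))
    .filter(expr("false"))   // prunable predicate; the observe node must still survive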
to trigger the exception ) }, - errorClass = "STREAMING_OUTPUT_MODE.UNSUPPORTED_OPERATION", + condition = "STREAMING_OUTPUT_MODE.UNSUPPORTED_OPERATION", sqlState = "42KDE", parameters = Map( "outputMode" -> OutputMode.Update().toString.toLowerCase(Locale.ROOT), @@ -625,7 +625,7 @@ class StreamingSessionWindowSuite extends StreamTest CheckAnswer() // this is just to trigger the exception ) }, - errorClass = "STREAMING_OUTPUT_MODE.UNSUPPORTED_OPERATION", + condition = "STREAMING_OUTPUT_MODE.UNSUPPORTED_OPERATION", sqlState = "42KDE", parameters = Map( "outputMode" -> OutputMode.Update().toString.toLowerCase(Locale.ROOT), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index 0f394aac8f782..fe88fbaa91cb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -103,7 +103,7 @@ class TransformWithMapStateSuite extends StreamTest ExpectFailure[SparkIllegalArgumentException] { e => { checkError( exception = e.asInstanceOf[SparkIllegalArgumentException], - errorClass = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE", + condition = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE", sqlState = Some("42601"), parameters = Map("stateName" -> "sessionState") ) @@ -152,7 +152,7 @@ class TransformWithMapStateSuite extends StreamTest ExpectFailure[SparkIllegalArgumentException] { e => { checkError( exception = e.asInstanceOf[SparkIllegalArgumentException], - errorClass = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE", + condition = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE", sqlState = Some("42601"), parameters = Map("stateName" -> "sessionState")) }} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala index b32d3c7e52013..b1025d9d89494 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala @@ -191,7 +191,7 @@ class TransformWithStateChainingSuite extends StreamTest { StartStream() ) }, - errorClass = "STREAMING_OUTPUT_MODE.UNSUPPORTED_OPERATION", + condition = "STREAMING_OUTPUT_MODE.UNSUPPORTED_OPERATION", sqlState = "42KDE", parameters = Map( "outputMode" -> "append", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala index 54cff6fc44c08..d141407b4fcd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala @@ -418,7 +418,7 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest } checkError( exception = e.getCause.asInstanceOf[SparkUnsupportedOperationException], - errorClass = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY", + condition = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY", sqlState = Some("42802"), parameters = Map("groupingKey" -> "init_1") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 9eeedd8598092..d0e255bb30499 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -21,12 +21,16 @@ import java.io.File import java.time.Duration import java.util.UUID -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} +import org.scalatest.matchers.must.Matchers.be +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper +import org.scalatest.time.{Seconds, Span} import org.apache.spark.{SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, Encoders, Row} import org.apache.spark.sql.catalyst.util.stringToFile +import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state._ @@ -708,7 +712,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest } checkError( ex.asInstanceOf[SparkRuntimeException], - errorClass = "STATE_STORE_HANDLE_NOT_INITIALIZED", + condition = "STATE_STORE_HANDLE_NOT_INITIALIZED", parameters = Map.empty ) } @@ -1151,7 +1155,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest ExpectFailure[StateStoreInvalidConfigAfterRestart] { e => checkError( e.asInstanceOf[SparkUnsupportedOperationException], - errorClass = "STATE_STORE_INVALID_CONFIG_AFTER_RESTART", + condition = "STATE_STORE_INVALID_CONFIG_AFTER_RESTART", parameters = Map( "configName" -> "outputMode", "oldConfig" -> "Update", @@ -1193,7 +1197,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest ExpectFailure[StateStoreInvalidVariableTypeChange] { t => checkError( t.asInstanceOf[SparkUnsupportedOperationException], - errorClass = "STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE", + condition = "STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE", parameters = Map( "stateVarName" -> "countState", "newType" -> "ListState", @@ -1240,7 +1244,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest ExpectFailure[StateStoreInvalidConfigAfterRestart] { e => checkError( e.asInstanceOf[SparkUnsupportedOperationException], - errorClass = "STATE_STORE_INVALID_CONFIG_AFTER_RESTART", + condition = "STATE_STORE_INVALID_CONFIG_AFTER_RESTART", parameters = Map( "configName" -> "timeMode", "oldConfig" -> "NoTime", @@ -1292,7 +1296,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest ExpectFailure[StateStoreValueSchemaNotCompatible] { t => checkError( t.asInstanceOf[SparkUnsupportedOperationException], - errorClass = "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE", + condition = "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE", parameters = Map( "storedValueSchema" -> "StructType(StructField(value,LongType,false))", "newValueSchema" -> @@ -1426,6 +1430,495 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } } + + private def getFiles(path: Path): Array[FileStatus] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val fileManager = CheckpointFileManager.create(path, hadoopConf) + fileManager.list(path) + } + + private def getStateSchemaPath(stateCheckpointPath: Path): Path = { + new Path(stateCheckpointPath, "_stateSchema/default/") + } + + test("transformWithState - verify that metadata and schema logs are purged") { + 
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "1") { + withTempDir { chkptDir => + val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0") + val stateSchemaPath = getStateSchemaPath(stateOpIdPath) + + val metadataPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) + // in this test case, we are changing the state spec back and forth + // to trigger the writing of the schema and metadata files + val inputData = MemoryStream[(String, String)] + val result1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new MostRecentStatefulProcessorWithDeletion(), + TimeMode.None(), + OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str2")), + CheckNewAnswer(("a", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + // assert that a metadata and schema file has been written for each run + // as state variables have been deleted + assert(getFiles(metadataPath).length == 2) + assert(getFiles(stateSchemaPath).length == 2) + + val result3 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + testStream(result3, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str3")), + CheckNewAnswer(("a", "1", "str2")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + // because we don't change the schema for this run, there won't + // be a new schema file written. 
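The comment above is the key invariant for the purge assertions that follow: a new schema file is written only when the state spec changes between runs. A standalone sketch of the resulting counting with MIN_BATCHES_TO_RETAIN = 1 and four batches (illustrative only, not Spark API):

  val retainedBatches = Seq(2L, 3L)                            // metadata kept only for the last two batches
  val schemaUsedByRetainedBatch = Map(2L -> "schema_v3", 3L -> "schema_v3") // batch 3 reused batch 2's schema
  assert(retainedBatches.size == 2)                            // getFiles(metadataPath).length == 2
  assert(schemaUsedByRetainedBatch.values.toSet.size == 1)     // getFiles(stateSchemaPath).length == 1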
+ testStream(result3, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str4")), + CheckNewAnswer(("a", "2", "str3")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + // By the end of the test, 4 batches have run, so the metadata log, schema log, + // and commit log have been purged for batches 0 and 1. Metadata and schema files + // were written for batches 0, 1, 2, 3, but we only need to keep the metadata files + // for batches 2 and 3, and since the schema hasn't changed between batches 2 and 3, + // we only keep the schema file for batch 2. + assert(getFiles(metadataPath).length == 2) + assert(getFiles(stateSchemaPath).length == 1) + } + } + } + + test("transformWithState - verify that schema file is kept after metadata is purged") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2") { + withTempDir { chkptDir => + val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0") + val stateSchemaPath = getStateSchemaPath(stateOpIdPath) + + val metadataPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) + // in this test case, we are changing the state spec back and forth + // to trigger the writing of the schema and metadata files + val inputData = MemoryStream[(String, String)] + val result1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "2", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new MostRecentStatefulProcessorWithDeletion(), + TimeMode.None(), + OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str2")), + CheckNewAnswer(("a", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + assert(getFiles(metadataPath).length == 3) + assert(getFiles(stateSchemaPath).length == 2) + + val result3 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + testStream(result3, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str3")), + CheckNewAnswer(("a", "1", "str2")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + 
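+ // At this point the state spec has changed twice across restarts. The oldest + // operator metadata file (batch 0) is now eligible for purging, but the batch 0 + // schema file must be retained because it is still needed to read the older + // state queried through the state data source below.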
// metadata files should be kept for batches 1, 2, 3 + // schema files should be kept for batches 0, 2, 3 + assert(getFiles(metadataPath).length == 3) + assert(getFiles(stateSchemaPath).length == 3) + // we want to ensure that we can read batch 1 even though the + // metadata file for batch 0 was removed + val batch1Df = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "countState") + .option(StateSourceOptions.BATCH_ID, 1) + .load() + + val batch1AnsDf = batch1Df.selectExpr( + "key.value AS groupingKey", + "single_value.value AS valueId") + + checkAnswer(batch1AnsDf, Seq(Row("a", 2L))) + + val batch3Df = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "countState") + .option(StateSourceOptions.BATCH_ID, 3) + .load() + + val batch3AnsDf = batch3Df.selectExpr( + "key.value AS groupingKey", + "single_value.value AS valueId") + checkAnswer(batch3AnsDf, Seq(Row("a", 1L))) + } + } + } + + test("state data source integration - value state supports time travel") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "5") { + withTempDir { chkptDir => + // in this test case, we are changing the state spec back and forth + // to trigger the writing of the schema and metadata files + val inputData = MemoryStream[(String, String)] + val result1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "2", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "3", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "4", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "5", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "6", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "7", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new MostRecentStatefulProcessorWithDeletion(), + TimeMode.None(), + OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str2")), + CheckNewAnswer(("a", "str1")), + AddData(inputData, ("a", "str3")), + CheckNewAnswer(("a", "str2")), + AddData(inputData, ("a", "str4")), + CheckNewAnswer(("a", "str3")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + + // Batches 0-7: countState, mostRecent + // Batches 8-9: countState + + // By this time, offset and commit logs for batches 0-3 have been purged. 
+ // However, if we want to read the data for batch 4, we need to read the corresponding + // metadata and schema file for batch 4, or the latest files that correspond to + // batch 4 (in this case, the files were written for batch 0). + // We want to test the behavior where the metadata files are preserved so that we can + // read from the state data source, even if the commit and offset logs are purged for + // a particular batch + + val df = spark.read.format("state-metadata").load(chkptDir.toString) + + // check the min and max batch ids that we have data for + checkAnswer( + df.select( + "operatorId", "operatorName", "stateStoreName", "numPartitions", "minBatchId", + "maxBatchId"), + Seq(Row(0, "transformWithStateExec", "default", 5, 4, 9)) + ) + + val countStateDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "countState") + .option(StateSourceOptions.BATCH_ID, 4) + .load() + + val countStateAnsDf = countStateDf.selectExpr( + "key.value AS groupingKey", + "single_value.value AS valueId") + checkAnswer(countStateAnsDf, Seq(Row("a", 5L))) + + val mostRecentDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "mostRecent") + .option(StateSourceOptions.BATCH_ID, 4) + .load() + + val mostRecentAnsDf = mostRecentDf.selectExpr( + "key.value AS groupingKey", + "single_value.value") + checkAnswer(mostRecentAnsDf, Seq(Row("a", "str1"))) + } + } + } + + test("transformWithState - verify that all metadata and schema logs are not purged") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "3") { + withTempDir { chkptDir => + val inputData = MemoryStream[(String, String)] + val result1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "2", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "3", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "4", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "5", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "6", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "7", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "8", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "9", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "10", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "11", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "12", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "13", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", 
"14", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + + val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0") + val stateSchemaPath = getStateSchemaPath(stateOpIdPath) + + val metadataPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) + + // Metadata files exist for batches 0, 12, and the thresholdBatchId is 8 + // as this is the earliest batchId for which the commit log is not present, + // so we need to keep metadata files for batch 0 so we can read the commit + // log correspondingly + assert(getFiles(metadataPath).length == 2) + assert(getFiles(stateSchemaPath).length == 1) + } + } + } + + test("transformWithState - verify that no metadata and schema logs are purged after" + + " removing column family") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "3") { + withTempDir { chkptDir => + val inputData = MemoryStream[(String, String)] + val result1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "2", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "1", "")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "2", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "3", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "4", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "5", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "6", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "7", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "8", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "9", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "10", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "11", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "12", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new MostRecentStatefulProcessorWithDeletion(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("b", "str2")), + CheckNewAnswer(("b", "str1")), + AddData(inputData, ("b", "str3")), + CheckNewAnswer(("b", "str2")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + + val 
stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0") + val stateSchemaPath = getStateSchemaPath(stateOpIdPath) + + val metadataPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) + + // Metadata files are written for batches 0, 2, and 14. + // Schema files are written for 0, 14 + // At the beginning of the last query run, the thresholdBatchId is 11. + // However, we would need both schema files to be preserved, if we want to + // be able to read from batch 11 onwards. + assert(getFiles(metadataPath).length == 2) + assert(getFiles(stateSchemaPath).length == 2) + } + } + } } class TransformWithStateValidationSuite extends StateStoreMetricsTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala index defd5fd110de6..a47c2f839692c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala @@ -265,7 +265,7 @@ class TriggerAvailableNowSuite extends FileStreamSourceTest { private def assertQueryUsingRightBatchExecutor( testSource: TestDataFrameProvider, query: StreamingQuery): Unit = { - val useWrapper = query.sparkSession.conf.get( + val useWrapper = query.sparkSession.sessionState.conf.getConf( SQLConf.STREAMING_TRIGGER_AVAILABLE_NOW_WRAPPER_ENABLED) if (useWrapper) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index e77ba92fe2981..544f910333bfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -324,7 +324,7 @@ class StreamingDataSourceV2Suite extends StreamTest { readFormat: String, writeFormat: String, trigger: Trigger, - errorClass: String, + condition: String, parameters: Map[String, String]) = { val query = spark.readStream .format(readFormat) @@ -339,7 +339,7 @@ class StreamingDataSourceV2Suite extends StreamTest { assert(query.exception.get.cause != null) checkErrorMatchPVals( exception = query.exception.get.cause.asInstanceOf[SparkUnsupportedOperationException], - errorClass = errorClass, + condition = condition, parameters = parameters ) } @@ -436,7 +436,7 @@ class StreamingDataSourceV2Suite extends StreamTest { exception = intercept[SparkUnsupportedOperationException] { testCase(read, write, trigger) }, - errorClass = "_LEGACY_ERROR_TEMP_2049", + condition = "_LEGACY_ERROR_TEMP_2049", parameters = Map( "className" -> "fake-read-neither-mode", "operator" -> "reading" @@ -449,7 +449,7 @@ class StreamingDataSourceV2Suite extends StreamTest { exception = intercept[SparkUnsupportedOperationException] { testCase(read, write, trigger) }, - errorClass = "_LEGACY_ERROR_TEMP_2049", + condition = "_LEGACY_ERROR_TEMP_2049", parameters = Map( "className" -> "fake-write-neither-mode", "operator" -> "writing" @@ -466,7 +466,7 @@ class StreamingDataSourceV2Suite extends StreamTest { exception = intercept[SparkUnsupportedOperationException] { testCase(read, write, trigger) }, - errorClass = "_LEGACY_ERROR_TEMP_2253", + condition = "_LEGACY_ERROR_TEMP_2253", parameters = Map("sourceName" -> "fake-read-microbatch-only") ) } @@ -478,7 +478,7 @@ class StreamingDataSourceV2Suite extends 
StreamTest { } else { // Invalid - trigger is microbatch but reader is not testPostCreationNegativeCase(read, write, trigger, - errorClass = "_LEGACY_ERROR_TEMP_2209", + condition = "_LEGACY_ERROR_TEMP_2209", parameters = Map( "srcName" -> read, "disabledSources" -> "", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index c4ec0af80b725..e74627f3f51e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -133,7 +133,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { .write .save() }, - errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", + condition = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", parameters = Map("methodName" -> "`write`")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala index 5ae7b3eec37e7..86c4e49f6f66f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala @@ -117,7 +117,7 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter { exception = intercept[AnalysisException] { spark.readStream.table(tableIdentifier) }, - errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", parameters = Map( "tableName" -> "`testcat`.`table_name`", "operation" -> "either micro-batch or continuous scan" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index b9e5c176f93e0..f3d21e384ed42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -165,7 +165,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with } checkError( exception = e, - errorClass = "WRITE_STREAM_NOT_ALLOWED", + condition = "WRITE_STREAM_NOT_ALLOWED", parameters = Map.empty ) } @@ -306,7 +306,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with exception = intercept[AnalysisException] { df.write.mode("append").clusterBy("a").saveAsTable("clusteredTable") }, - errorClass = "CLUSTERING_COLUMNS_MISMATCH", + condition = "CLUSTERING_COLUMNS_MISMATCH", parameters = Map( "tableName" -> "spark_catalog.default.clusteredtable", "specifiedClusteringString" -> """[["a"]]""", @@ -317,7 +317,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with exception = intercept[AnalysisException] { df.write.mode("append").clusterBy("b", "a").saveAsTable("clusteredTable") }, - errorClass = "CLUSTERING_COLUMNS_MISMATCH", + condition = "CLUSTERING_COLUMNS_MISMATCH", parameters = Map( "tableName" -> "spark_catalog.default.clusteredtable", "specifiedClusteringString" -> """[["b"],["a"]]""", @@ -328,7 +328,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with exception = intercept[AnalysisException] { df.write.mode("append").saveAsTable("clusteredTable") }, - errorClass = 
"CLUSTERING_COLUMNS_MISMATCH", + condition = "CLUSTERING_COLUMNS_MISMATCH", parameters = Map( "tableName" -> "spark_catalog.default.clusteredtable", "specifiedClusteringString" -> "", "existingClusteringString" -> """[["a"],["b"]]""") @@ -455,7 +455,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with exception = intercept[AnalysisException] { Seq((1L, 2.0)).toDF("i", "d").write.mode("append").saveAsTable("t") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`i`", @@ -483,7 +483,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with exception = intercept[AnalysisException] { Seq(("a", "b")).toDF("i", "d").write.mode("append").saveAsTable("t") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`i`", @@ -495,7 +495,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with exception = intercept[AnalysisException] { Seq((true, false)).toDF("i", "d").write.mode("append").saveAsTable("t") }, - errorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_SAFELY_CAST", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t`", "colName" -> "`i`", @@ -728,7 +728,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with exception = intercept[AnalysisException] { testRead(spark.read.csv(), Seq.empty, schema) }, - errorClass = "UNABLE_TO_INFER_SCHEMA", + condition = "UNABLE_TO_INFER_SCHEMA", parameters = Map("format" -> "CSV") ) @@ -1066,13 +1066,13 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with exception = intercept[AnalysisException] { Seq((1, 1)).toDF("col", c0).write.bucketBy(2, c0, c1).saveAsTable("t") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c1.toLowerCase(Locale.ROOT)}`")) checkError( exception = intercept[AnalysisException] { Seq((1, 1)).toDF("col", c0).write.bucketBy(2, "col").sortBy(c0, c1).saveAsTable("t") }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c1.toLowerCase(Locale.ROOT)}`")) } } @@ -1086,7 +1086,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with Seq((1, 1)).toDF(colName0, colName1).write.format(format).mode("overwrite") .save(tempDir.getAbsolutePath) }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${colName1.toLowerCase(Locale.ROOT)}`")) } @@ -1099,7 +1099,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with spark.read.format(format).schema(s"$colName0 INT, $colName1 INT") .load(testDir.getAbsolutePath) }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${colName1.toLowerCase(Locale.ROOT)}`")) } @@ -1112,7 +1112,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with exception = intercept[AnalysisException] { spark.read.format(format).load(testDir.getAbsolutePath) }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> 
s"`${colName1.toLowerCase(Locale.ROOT)}`")) } @@ -1142,7 +1142,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with exception = intercept[AnalysisException] { spark.read.format("json").option("inferSchema", true).load(testDir.getAbsolutePath) }, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> s"`${c1.toLowerCase(Locale.ROOT)}`")) checkReadPartitionColumnDuplication("json", c0, c1, src) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index ff1473fea369b..4d4cc44eb3e72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -103,6 +103,8 @@ trait SharedSparkSessionBase new TestSparkSession(sparkConf) } + protected def sqlConf: SQLConf = _spark.sessionState.conf + /** * Initialize the [[TestSparkSession]]. Generally, this is just called from * beforeAll; however, in test using styles other than FunSuite, there is diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 808f783a595ac..be91f5e789e2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -118,7 +118,7 @@ class DataFrameCallbackSuite extends QueryTest sparkContext.listenerBus.waitUntilEmpty() assert(metrics.length == 2) - assert(metrics(0)._1 == "foreach") + assert(metrics(0)._1 == "foreachPartition") assert(metrics(1)._1 == "reduce") spark.listenerManager.unregister(listener) diff --git a/sql/gen-sql-api-docs.py b/sql/gen-sql-api-docs.py index 17631a7352a02..3d19da01b3938 100644 --- a/sql/gen-sql-api-docs.py +++ b/sql/gen-sql-api-docs.py @@ -69,19 +69,6 @@ note="", since="1.0.0", deprecated=""), - ExpressionInfo( - className="", - name="between", - usage="expr1 [NOT] BETWEEN expr2 AND expr3 - " + - "evaluate if `expr1` is [not] in between `expr2` and `expr3`.", - arguments="", - examples="\n Examples:\n " + - "> SELECT col1 FROM VALUES 1, 3, 5, 7 WHERE col1 BETWEEN 2 AND 5;\n " + - " 3\n " + - " 5", - note="", - since="1.0.0", - deprecated=""), ExpressionInfo( className="", name="case", diff --git a/sql/gen-sql-functions-docs.py b/sql/gen-sql-functions-docs.py index bb813cffb0128..a1facbaaf7e3b 100644 --- a/sql/gen-sql-functions-docs.py +++ b/sql/gen-sql-functions-docs.py @@ -36,9 +36,14 @@ "bitwise_funcs", "conversion_funcs", "csv_funcs", "xml_funcs", "lambda_funcs", "collection_funcs", "url_funcs", "hash_funcs", "struct_funcs", + "table_funcs", "variant_funcs" } +def _print_red(text): + print('\033[31m' + text + '\033[0m') + + def _list_grouped_function_infos(jvm): """ Returns a list of function information grouped by each group value via JVM. @@ -126,7 +131,13 @@ def _make_pretty_usage(infos): func_name = "\\" + func_name elif (info.name == "when"): func_name = "CASE WHEN" - usages = iter(re.split(r"(.*%s.*) - " % func_name, info.usage.strip())[1:]) + expr_usages = re.split(r"(.*%s.*) - " % func_name, info.usage.strip()) + if len(expr_usages) <= 1: + _print_red("\nThe `usage` of %s is not standardized, please correct it. 
" + "Refer to: `AesDecrypt`" % (func_name)) + os._exit(-1) + usages = iter(expr_usages[1:]) + for (sig, description) in zip(usages, usages): result.append(" ") result.append(" %s" % sig) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 3ccbd23b71c98..4575549005f33 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -1430,9 +1430,9 @@ abstract class HiveThriftServer2TestBase extends SparkFunSuite with BeforeAndAft protected def jdbcUri(database: String = "default"): String = if (mode == ServerMode.http) { s"""jdbc:hive2://$localhost:$serverPort/ - |$database? - |hive.server2.transport.mode=http; - |hive.server2.thrift.http.path=cliservice; + |$database; + |transportMode=http; + |httpPath=cliservice;? |${hiveConfList}#${hiveVarList} """.stripMargin.split("\n").mkString.trim } else { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala index 026b2388c593c..331572e62f566 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala @@ -103,7 +103,8 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ // SPARK-42921 "timestampNTZ/datetime-special-ansi.sql", // SPARK-47264 - "collations.sql" + "collations.sql", + "pipe-operators.sql" ) override def runQueries( diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index 7005f0e951b2b..dcf3bd8c71731 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") // scalastyle:on line.size.limit } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index c455c2cef15fd..dbeb8607facc2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, ReplaceCharWithVarchar, ResolveSessionCatalog} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, 
EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose} import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -94,6 +94,8 @@ class HiveSessionStateBuilder( ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: new DetermineTableStats(session) +: + new ResolveTranspose(session) +: + new InvokeProcedures(session) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala index 600fddd797ca4..2a15d5b4dcb45 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala @@ -59,7 +59,7 @@ class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSi exception = intercept[SparkException] { sql("select count(*) from view_refresh").first() }, - errorClass = "FAILED_READ_FILE.FILE_NOT_EXIST", + condition = "FAILED_READ_FILE.FILE_NOT_EXIST", parameters = Map("path" -> ".*") ) @@ -102,7 +102,7 @@ class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSi exception = intercept[SparkException] { sql("select * from test").count() }, - errorClass = "FAILED_READ_FILE.FILE_NOT_EXIST", + condition = "FAILED_READ_FILE.FILE_NOT_EXIST", parameters = Map("path" -> ".*") ) @@ -120,7 +120,7 @@ class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSi exception = intercept[SparkException] { sql("select * from test").count() }, - errorClass = "FAILED_READ_FILE.FILE_NOT_EXIST", + condition = "FAILED_READ_FILE.FILE_NOT_EXIST", parameters = Map("path" -> ".*") ) spark.catalog.refreshByPath(dir.getAbsolutePath) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index f9a24f44b76c0..72c570d1f9097 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -358,7 +358,7 @@ class DataSourceWithHiveMetastoreCatalogSuite |""".stripMargin) checkError( exception = intercept[AnalysisException](spark.table("non_partition_table")), - errorClass = "_LEGACY_ERROR_TEMP_3096", + condition = "_LEGACY_ERROR_TEMP_3096", parameters = Map( "resLen" -> "2", "relLen" -> "1", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 2152a7e300021..6d7248a7dd67f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -128,7 +128,7 @@ class HiveParquetSuite extends QueryTest } checkError( exception = ex, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`c3`", "proposal" -> "`c1`, `c2`"), context = ExpectedContext( fragment = "c3", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSQLInsertTestSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSQLInsertTestSuite.scala index d6ba38359f496..4109c0a127065 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSQLInsertTestSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSQLInsertTestSuite.scala @@ -45,7 +45,7 @@ class HiveSQLInsertTestSuite extends SQLInsertTestSuite with TestHiveSingleton { v2ErrorClass: String, v1Parameters: Map[String, String], v2Parameters: Map[String, String]): Unit = { - checkError(exception = exception, sqlState = None, errorClass = v1ErrorClass, + checkError(exception = exception, sqlState = None, condition = v1ErrorClass, parameters = v1Parameters) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSharedStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSharedStateSuite.scala index d84b9f7960231..8c6113fb5569d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSharedStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSharedStateSuite.scala @@ -86,7 +86,7 @@ class HiveSharedStateSuite extends SparkFunSuite { assert(ss2.sparkContext.hadoopConfiguration.get("hive.metastore.warehouse.dir") !== invalidPath, "warehouse conf in session options can't affect application wide hadoop conf") assert(ss.conf.get("spark.foo") === "bar2222", "session level conf should be passed to catalog") - assert(!ss.conf.get(WAREHOUSE_PATH).contains(invalidPath), + assert(!ss.conf.get(WAREHOUSE_PATH.key).contains(invalidPath), "session level conf should be passed to catalog") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 69abb1d1673ed..865ce81e151c2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -828,7 +828,7 @@ object SPARK_18360 { .enableHiveSupport().getOrCreate() val defaultDbLocation = spark.catalog.getDatabase("default").locationUri - assert(new Path(defaultDbLocation) == new Path(spark.conf.get(WAREHOUSE_PATH))) + assert(new Path(defaultDbLocation) == new Path(spark.conf.get(WAREHOUSE_PATH.key))) val hiveClient = spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index ea43f1d2c6729..cc7bb193731f5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -350,7 +350,7 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter exception = intercept[AnalysisException] { Seq((1, 2, 3, 4)).toDF("a", "b", "c", "d").write.partitionBy("b", "c").insertInto(tableName) }, - errorClass = "_LEGACY_ERROR_TEMP_1309", + condition = "_LEGACY_ERROR_TEMP_1309", parameters = Map.empty ) } @@ -362,7 +362,7 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter exception = intercept[AnalysisException] { sql(s"INSERT INTO TABLE $tableName PARTITION(b=1, c=2) SELECT 1, 2, 3") }, - errorClass = "INSERT_PARTITION_COLUMN_ARITY_MISMATCH", + condition = "INSERT_PARTITION_COLUMN_ARITY_MISMATCH", parameters = Map( "staticPartCols" -> "`b`, `c`", "tableColumns" -> "`a`, `d`, `b`, `c`", @@ -720,7 +720,7 @@ class InsertSuite extends QueryTest with 
TestHiveSingleton with BeforeAndAfter |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' |SELECT * FROM test_insert_table""".stripMargin) }, - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'ROW'", "hint" -> "")) } } @@ -740,7 +740,7 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' |SELECT * FROM test_insert_table""".stripMargin) }, - errorClass = "PARSE_SYNTAX_ERROR", + condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'ROW'", "hint" -> "")) } } @@ -809,7 +809,7 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter } checkError( exception = e, - errorClass = "COLUMN_ALREADY_EXISTS", + condition = "COLUMN_ALREADY_EXISTS", parameters = Map("columnName" -> "`id`")) } } @@ -858,7 +858,7 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter |SELECT 1 """.stripMargin) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> "The spec ([d=Some()]) contains an empty partition column value") ) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index cde0da67e83e9..f2cab33dea76a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -575,7 +575,7 @@ class MetastoreDataSourcesSuite extends QueryTest table("createdJsonTable") }, - errorClass = "UNABLE_TO_INFER_SCHEMA", + condition = "UNABLE_TO_INFER_SCHEMA", parameters = Map("format" -> "JSON") ) @@ -925,7 +925,7 @@ class MetastoreDataSourcesSuite extends QueryTest createDF(10, 19).write.mode(SaveMode.Append).format("orc"). saveAsTable("appendOrcToParquet") }, - errorClass = "_LEGACY_ERROR_TEMP_1159", + condition = "_LEGACY_ERROR_TEMP_1159", parameters = Map( "tableName" -> s"$SESSION_CATALOG_NAME.default.appendorctoparquet", "existingProvider" -> "ParquetDataSourceV2", @@ -941,7 +941,7 @@ class MetastoreDataSourcesSuite extends QueryTest createDF(10, 19).write.mode(SaveMode.Append).format("parquet") .saveAsTable("appendParquetToJson") }, - errorClass = "_LEGACY_ERROR_TEMP_1159", + condition = "_LEGACY_ERROR_TEMP_1159", parameters = Map( "tableName" -> s"$SESSION_CATALOG_NAME.default.appendparquettojson", "existingProvider" -> "JsonDataSourceV2", @@ -957,7 +957,7 @@ class MetastoreDataSourcesSuite extends QueryTest createDF(10, 19).write.mode(SaveMode.Append).format("text") .saveAsTable("appendTextToJson") }, - errorClass = "_LEGACY_ERROR_TEMP_1159", + condition = "_LEGACY_ERROR_TEMP_1159", // The format of the existing table can be JsonDataSourceV2 or JsonFileFormat. 
parameters = Map( "tableName" -> s"$SESSION_CATALOG_NAME.default.appendtexttojson", @@ -1232,7 +1232,7 @@ class MetastoreDataSourcesSuite extends QueryTest Seq((3, 4)).toDF("i", "k") .write.mode("append").saveAsTable("saveAsTable_mismatch_column_names") }, - errorClass = "_LEGACY_ERROR_TEMP_1162", + condition = "_LEGACY_ERROR_TEMP_1162", parameters = Map("col" -> "j", "inputColumns" -> "i, k")) } } @@ -1245,7 +1245,7 @@ class MetastoreDataSourcesSuite extends QueryTest Seq((3, 4, 5)).toDF("i", "j", "k") .write.mode("append").saveAsTable("saveAsTable_too_many_columns") }, - errorClass = "_LEGACY_ERROR_TEMP_1161", + condition = "_LEGACY_ERROR_TEMP_1161", parameters = Map( "tableName" -> "spark_catalog.default.saveastable_too_many_columns", "existingTableSchema" -> "struct", @@ -1265,7 +1265,7 @@ class MetastoreDataSourcesSuite extends QueryTest |USING hive """.stripMargin) }, - errorClass = "_LEGACY_ERROR_TEMP_1293", + condition = "_LEGACY_ERROR_TEMP_1293", parameters = Map.empty ) } @@ -1288,7 +1288,7 @@ class MetastoreDataSourcesSuite extends QueryTest exception = intercept[AnalysisException] { table(tableName).write.mode(SaveMode.Overwrite).saveAsTable(tableName) }, - errorClass = "UNSUPPORTED_OVERWRITE.TABLE", + condition = "UNSUPPORTED_OVERWRITE.TABLE", parameters = Map("table" -> s"`$SESSION_CATALOG_NAME`.`default`.`tab1`") ) @@ -1296,7 +1296,7 @@ class MetastoreDataSourcesSuite extends QueryTest exception = intercept[AnalysisException] { table(tableName).write.mode(SaveMode.ErrorIfExists).saveAsTable(tableName) }, - errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> s"`$SESSION_CATALOG_NAME`.`default`.`tab1`") ) } @@ -1326,7 +1326,7 @@ class MetastoreDataSourcesSuite extends QueryTest exception = intercept[AnalysisException] { table(tableName).write.mode(SaveMode.Overwrite).insertInto(tableName) }, - errorClass = "UNSUPPORTED_OVERWRITE.TABLE", + condition = "UNSUPPORTED_OVERWRITE.TABLE", parameters = Map("table" -> s"`$SESSION_CATALOG_NAME`.`default`.`tab1`") ) } @@ -1339,7 +1339,7 @@ class MetastoreDataSourcesSuite extends QueryTest exception = intercept[AnalysisException] { Seq(4).toDF("j").write.mode("append").saveAsTable("saveAsTable_less_columns") }, - errorClass = "_LEGACY_ERROR_TEMP_1161", + condition = "_LEGACY_ERROR_TEMP_1161", parameters = Map( "tableName" -> "spark_catalog.default.saveastable_less_columns", "existingTableSchema" -> "struct", @@ -1396,7 +1396,7 @@ class MetastoreDataSourcesSuite extends QueryTest exception = intercept[AnalysisException] { sharedState.externalCatalog.getTable("default", "t") }, - errorClass = "INSUFFICIENT_TABLE_PROPERTY.MISSING_KEY", + condition = "INSUFFICIENT_TABLE_PROPERTY.MISSING_KEY", parameters = Map("key" -> toSQLConf("spark.sql.sources.schema")) ) @@ -1417,7 +1417,7 @@ class MetastoreDataSourcesSuite extends QueryTest exception = intercept[AnalysisException] { sharedState.externalCatalog.getTable("default", "t2") }, - errorClass = "INSUFFICIENT_TABLE_PROPERTY.MISSING_KEY_PART", + condition = "INSUFFICIENT_TABLE_PROPERTY.MISSING_KEY_PART", parameters = Map( "key" -> toSQLConf("spark.sql.sources.schema.part.1"), "totalAmountOfParts" -> "3") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 55be6102a8535..0b10829f66910 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -272,7 +272,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle df.write.format("parquet").saveAsTable("`d:b`.`t:a`") } checkError(e, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> "`spark_catalog`.`d:b`")) } @@ -281,7 +281,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle df.write.format("parquet").saveAsTable("`d:b`.`table`") } checkError(e, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> "`spark_catalog`.`d:b`")) } @@ -297,7 +297,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle |) """.stripMargin) } - checkError(e, errorClass = "INVALID_SCHEMA_OR_RELATION_NAME", + checkError(e, condition = "INVALID_SCHEMA_OR_RELATION_NAME", parameters = Map("name" -> "`t:a`")) } @@ -313,7 +313,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle """.stripMargin) } checkError(e, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> "`spark_catalog`.`d:b`")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 7dc7fc41dc708..9c2f4461ff263 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -668,7 +668,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto exception = intercept[AnalysisException] { sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS fakeColumn") }, - errorClass = "COLUMN_NOT_FOUND", + condition = "COLUMN_NOT_FOUND", parameters = Map( "colName" -> "`fakeColumn`", "caseSensitiveConfig" -> "\"spark.sql.caseSensitive\"" @@ -1706,7 +1706,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto exception = intercept[AnalysisException] { sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS value") }, - errorClass = "UNSUPPORTED_FEATURE.ANALYZE_UNSUPPORTED_COLUMN_TYPE", + condition = "UNSUPPORTED_FEATURE.ANALYZE_UNSUPPORTED_COLUMN_TYPE", parameters = Map( "columnType" -> "\"MAP\"", "columnName" -> "`value`", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index b60adfb6f4cf1..07f212d2dcabb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -186,7 +186,7 @@ class HiveClientSuite(version: String) extends HiveVersionSuite(version) { assert(false, "dropDatabase should throw HiveException") } checkError(ex, - errorClass = "SCHEMA_NOT_EMPTY", + condition = "SCHEMA_NOT_EMPTY", parameters = Map("schemaName" -> "`temporary`")) client.dropDatabase("temporary", ignoreIfNotExists = false, cascade = true) @@ -485,7 +485,7 @@ class HiveClientSuite(version: String) extends HiveVersionSuite(version) { client.createPartitions("default", "src_part", partitions, ignoreIfExists = false) } checkError(e, - errorClass = "PARTITIONS_ALREADY_EXIST", + condition = "PARTITIONS_ALREADY_EXIST", parameters = Map("partitionList" -> "PARTITION (`key1` = 101, `key2` = 102)", "tableName" -> 
"`default`.`src_part`")) } finally { @@ -577,7 +577,7 @@ class HiveClientSuite(version: String) extends HiveVersionSuite(version) { exception = intercept[AnalysisException] { versionSpark.table("mv1").collect() }, - errorClass = "UNSUPPORTED_FEATURE.HIVE_TABLE_TYPE", + condition = "UNSUPPORTED_FEATURE.HIVE_TABLE_TYPE", parameters = Map( "tableName" -> "`mv1`", "tableType" -> "materialized view" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 4b000fff0eb92..7c9b0b7781427 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -526,7 +526,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te |GROUP BY key """.stripMargin) }, - errorClass = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION", + condition = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION", parameters = Map("sqlExpr" -> "\"mydoublesum(((value + (1.5 * key)) + rand()))\""), context = ExpectedContext( fragment = "value + 1.5 * key + rand()", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 15dbd6aaa5b06..69d54a746b55d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -399,14 +399,14 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("CREATE TABLE tab1 USING hive") }, - errorClass = "_LEGACY_ERROR_TEMP_3083", + condition = "_LEGACY_ERROR_TEMP_3083", parameters = Map("tableName" -> "`spark_catalog`.`default`.`tab1`") ) checkError( exception = intercept[AnalysisException] { sql(s"CREATE TABLE tab2 USING hive location '${tempDir.getCanonicalPath}'") }, - errorClass = "_LEGACY_ERROR_TEMP_3083", + condition = "_LEGACY_ERROR_TEMP_3083", parameters = Map("tableName" -> "`spark_catalog`.`default`.`tab2`") ) } @@ -530,7 +530,7 @@ class HiveDDLSuite } test("create table: partition column names exist in table definition") { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( "CREATE TABLE tbl(a int) PARTITIONED BY (a string)", "COLUMN_ALREADY_EXISTS", Map("columnName" -> "`a`")) @@ -542,7 +542,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "COLUMN_NOT_DEFINED_IN_TABLE", + condition = "COLUMN_NOT_DEFINED_IN_TABLE", parameters = Map( "colType" -> "partition", "colName" -> "`b`", @@ -605,7 +605,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "_LEGACY_ERROR_TEMP_1076", + condition = "_LEGACY_ERROR_TEMP_1076", parameters = Map( "details" -> "The spec ([partCol1=]) contains an empty partition column value") ) @@ -657,7 +657,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $externalTab DROP PARTITION (ds='2008-04-09', unknownCol='12')") }, - errorClass = "_LEGACY_ERROR_TEMP_1231", + condition = "_LEGACY_ERROR_TEMP_1231", parameters = Map( "key" -> "unknownCol", "tblName" -> s"`$SESSION_CATALOG_NAME`.`default`.`exttable_with_partitions`") @@ -770,13 +770,12 @@ class HiveDDLSuite } } - private def assertAnalysisErrorClass( + private def assertAnalysisErrorCondition( sqlText: String, - errorClass: String, + condition: 
String, parameters: Map[String, String]): Unit = { val e = intercept[AnalysisException](sql(sqlText)) - checkError(e, - errorClass = errorClass, parameters = parameters) + checkError(e, condition = condition, parameters = parameters) } test("create table - SET TBLPROPERTIES EXTERNAL to TRUE") { @@ -787,7 +786,7 @@ class HiveDDLSuite sql(s"CREATE TABLE $tabName (height INT, length INT) " + s"TBLPROPERTIES('EXTERNAL'='TRUE')") }, - errorClass = "_LEGACY_ERROR_TEMP_3087", + condition = "_LEGACY_ERROR_TEMP_3087", parameters = Map.empty ) } @@ -804,7 +803,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $tabName SET TBLPROPERTIES ('EXTERNAL' = 'TRUE')") }, - errorClass = "_LEGACY_ERROR_TEMP_3087", + condition = "_LEGACY_ERROR_TEMP_3087", parameters = Map.empty ) // The table type is not changed to external @@ -836,7 +835,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER VIEW $tabName RENAME TO $newViewName") }, - errorClass = "_LEGACY_ERROR_TEMP_1253", + condition = "_LEGACY_ERROR_TEMP_1253", parameters = Map.empty ) @@ -844,7 +843,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $oldViewName RENAME TO $newViewName") }, - errorClass = "_LEGACY_ERROR_TEMP_1252", + condition = "_LEGACY_ERROR_TEMP_1252", parameters = Map.empty ) @@ -852,7 +851,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER VIEW $tabName SET TBLPROPERTIES ('p' = 'an')") }, - errorClass = "EXPECT_VIEW_NOT_TABLE.USE_ALTER_TABLE", + condition = "EXPECT_VIEW_NOT_TABLE.USE_ALTER_TABLE", parameters = Map( "tableName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$tabName`", "operation" -> "ALTER VIEW ... SET TBLPROPERTIES"), @@ -863,7 +862,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $oldViewName SET TBLPROPERTIES ('p' = 'an')") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", + condition = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$oldViewName`", "operation" -> "ALTER TABLE ... SET TBLPROPERTIES"), @@ -874,7 +873,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER VIEW $tabName UNSET TBLPROPERTIES ('p')") }, - errorClass = "EXPECT_VIEW_NOT_TABLE.USE_ALTER_TABLE", + condition = "EXPECT_VIEW_NOT_TABLE.USE_ALTER_TABLE", parameters = Map( "tableName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$tabName`", "operation" -> "ALTER VIEW ... UNSET TBLPROPERTIES"), @@ -885,7 +884,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $oldViewName UNSET TBLPROPERTIES ('p')") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", + condition = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$oldViewName`", "operation" -> "ALTER TABLE ... UNSET TBLPROPERTIES"), @@ -896,7 +895,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $oldViewName SET LOCATION '/path/to/home'") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$oldViewName`", "operation" -> "ALTER TABLE ... 
SET LOCATION ..."), @@ -907,7 +906,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $oldViewName SET SERDE 'whatever'") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", + condition = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$oldViewName`", "operation" -> "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]"), @@ -918,7 +917,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $oldViewName SET SERDEPROPERTIES ('x' = 'y')") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", + condition = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$oldViewName`", "operation" -> "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]"), @@ -929,7 +928,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $oldViewName PARTITION (a=1, b=2) SET SERDEPROPERTIES ('x' = 'y')") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", + condition = "EXPECT_TABLE_NOT_VIEW.USE_ALTER_VIEW", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$oldViewName`", "operation" -> "ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]"), @@ -940,7 +939,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $oldViewName RECOVER PARTITIONS") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$oldViewName`", "operation" -> "ALTER TABLE ... RECOVER PARTITIONS"), @@ -951,7 +950,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $oldViewName PARTITION (a='1') RENAME TO PARTITION (a='100')") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$oldViewName`", "operation" -> "ALTER TABLE ... RENAME TO PARTITION"), @@ -962,7 +961,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $oldViewName ADD IF NOT EXISTS PARTITION (a='4', b='8')") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$oldViewName`", "operation" -> "ALTER TABLE ... ADD PARTITION ..."), @@ -973,7 +972,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE $oldViewName DROP IF EXISTS PARTITION (a='2')") }, - errorClass = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", + condition = "EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$oldViewName`", "operation" -> "ALTER TABLE ... 
DROP PARTITION ..."), @@ -1117,9 +1116,9 @@ class HiveDDLSuite test("drop table using drop view") { withTable("tab1") { sql("CREATE TABLE tab1(c1 int)") - assertAnalysisErrorClass( + assertAnalysisErrorCondition( sqlText = "DROP VIEW tab1", - errorClass = "WRONG_COMMAND_FOR_OBJECT_TYPE", + condition = "WRONG_COMMAND_FOR_OBJECT_TYPE", parameters = Map( "alternative" -> "DROP TABLE", "operation" -> "DROP VIEW", @@ -1136,9 +1135,9 @@ class HiveDDLSuite spark.range(10).write.saveAsTable("tab1") withView("view1") { sql("CREATE VIEW view1 AS SELECT * FROM tab1") - assertAnalysisErrorClass( + assertAnalysisErrorCondition( sqlText = "DROP TABLE view1", - errorClass = "WRONG_COMMAND_FOR_OBJECT_TYPE", + condition = "WRONG_COMMAND_FOR_OBJECT_TYPE", parameters = Map( "alternative" -> "DROP VIEW", "operation" -> "DROP TABLE", @@ -1159,7 +1158,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("CREATE VIEW view1 (col1, col3) AS SELECT * FROM tab1") }, - errorClass = "CREATE_VIEW_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "CREATE_VIEW_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`view1`", "viewColumns" -> "`col1`, `col3`", @@ -1175,7 +1174,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("CREATE VIEW view2 (col1, col3) AS SELECT * FROM tab2") }, - errorClass = "CREATE_VIEW_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", + condition = "CREATE_VIEW_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS", parameters = Map( "viewName" -> s"`$SESSION_CATALOG_NAME`.`default`.`view2`", "viewColumns" -> "`col1`, `col3`", @@ -1322,7 +1321,7 @@ class HiveDDLSuite sql(s"USE default") val sqlDropDatabase = s"DROP DATABASE $dbName ${if (cascade) "CASCADE" else "RESTRICT"}" if (tableExists && !cascade) { - assertAnalysisErrorClass( + assertAnalysisErrorCondition( sqlDropDatabase, "SCHEMA_NOT_EMPTY", Map("schemaName" -> s"`$dbName`")) @@ -1358,7 +1357,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("DROP DATABASE default") }, - errorClass = "UNSUPPORTED_FEATURE.DROP_DATABASE", + condition = "UNSUPPORTED_FEATURE.DROP_DATABASE", parameters = Map("database" -> "`spark_catalog`.`default`") ) @@ -1368,7 +1367,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("DROP DATABASE DeFault") }, - errorClass = caseSensitive match { + condition = caseSensitive match { case "false" => "UNSUPPORTED_FEATURE.DROP_DATABASE" case _ => "_LEGACY_ERROR_TEMP_3065" }, @@ -1764,7 +1763,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { spark.catalog.getTable("default", indexTabName) }, - errorClass = "UNSUPPORTED_FEATURE.HIVE_TABLE_TYPE", + condition = "UNSUPPORTED_FEATURE.HIVE_TABLE_TYPE", parameters = Map( "tableName" -> s"`$indexTabName`", "tableType" -> "index table") @@ -1774,7 +1773,7 @@ class HiveDDLSuite exception = intercept[TableAlreadyExistsException] { sql(s"CREATE TABLE $indexTabName(b int) USING hive") }, - errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> s"`default`.`$indexTabName`") ) @@ -1782,7 +1781,7 @@ class HiveDDLSuite exception = intercept[TableAlreadyExistsException] { sql(s"ALTER TABLE $tabName RENAME TO $indexTabName") }, - errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> s"`default`.`$indexTabName`") ) @@ -1791,7 +1790,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { 
sql(s"DESCRIBE $indexTabName") }, - errorClass = "UNSUPPORTED_FEATURE.HIVE_TABLE_TYPE", + condition = "UNSUPPORTED_FEATURE.HIVE_TABLE_TYPE", parameters = Map( "tableName" -> s"`$indexTabName`", "tableType" -> "index table") @@ -1869,7 +1868,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"ALTER TABLE tbl SET TBLPROPERTIES ('${forbiddenPrefix}foo' = 'loser')") }, - errorClass = "_LEGACY_ERROR_TEMP_3086", + condition = "_LEGACY_ERROR_TEMP_3086", parameters = Map( "tableName" -> "spark_catalog.default.tbl", "invalidKeys" -> s"[${forbiddenPrefix}foo]") @@ -1878,7 +1877,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"CREATE TABLE tbl2 (a INT) TBLPROPERTIES ('${forbiddenPrefix}foo'='anything')") }, - errorClass = "_LEGACY_ERROR_TEMP_3086", + condition = "_LEGACY_ERROR_TEMP_3086", parameters = Map( "tableName" -> "spark_catalog.default.tbl2", "invalidKeys" -> s"[${forbiddenPrefix}foo]") @@ -1987,7 +1986,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { spark.table("t").write.format("hive").mode("overwrite").saveAsTable("t") }, - errorClass = "UNSUPPORTED_OVERWRITE.TABLE", + condition = "UNSUPPORTED_OVERWRITE.TABLE", parameters = Map("table" -> s"`$SESSION_CATALOG_NAME`.`default`.`t`")) } } @@ -2380,7 +2379,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("INSERT INTO TABLE t SELECT 1") }, - errorClass = "_LEGACY_ERROR_TEMP_3065", + condition = "_LEGACY_ERROR_TEMP_3065", parameters = Map( "clazz" -> "java.lang.IllegalArgumentException", "msg" -> "java.net.URISyntaxException: Relative path in absolute URI: a:b") @@ -2427,7 +2426,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") }, - errorClass = "_LEGACY_ERROR_TEMP_3065", + condition = "_LEGACY_ERROR_TEMP_3065", parameters = Map( "clazz" -> "java.lang.IllegalArgumentException", "msg" -> "java.net.URISyntaxException: Relative path in absolute URI: a:b") @@ -2437,7 +2436,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") }, - errorClass = "_LEGACY_ERROR_TEMP_3065", + condition = "_LEGACY_ERROR_TEMP_3065", parameters = Map( "clazz" -> "java.lang.IllegalArgumentException", "msg" -> "java.net.URISyntaxException: Relative path in absolute URI: a:b") @@ -2526,13 +2525,13 @@ class HiveDDLSuite sql("CREATE TABLE tab (c1 int) PARTITIONED BY (c2 int) STORED AS PARQUET") if (!caseSensitive) { // duplicating partitioning column name - assertAnalysisErrorClass( + assertAnalysisErrorCondition( "ALTER TABLE tab ADD COLUMNS (C2 string)", "COLUMN_ALREADY_EXISTS", Map("columnName" -> "`c2`")) // duplicating data column name - assertAnalysisErrorClass( + assertAnalysisErrorCondition( "ALTER TABLE tab ADD COLUMNS (C1 string)", "COLUMN_ALREADY_EXISTS", Map("columnName" -> "`c1`")) @@ -2543,7 +2542,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("ALTER TABLE tab ADD COLUMNS (C2 string)") }, - errorClass = "_LEGACY_ERROR_TEMP_3065", + condition = "_LEGACY_ERROR_TEMP_3065", parameters = Map( "clazz" -> "org.apache.hadoop.hive.ql.metadata.HiveException", "msg" -> "Partition column name c2 conflicts with table columns.") @@ -2555,7 +2554,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("ALTER TABLE tab ADD COLUMNS (C1 string)") }, - errorClass = "_LEGACY_ERROR_TEMP_3065", + condition = "_LEGACY_ERROR_TEMP_3065", parameters = Map( "clazz" -> 
"org.apache.hadoop.hive.ql.metadata.HiveException", "msg" -> "Duplicate column name c1 in the table definition.") @@ -2573,7 +2572,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("CREATE TABLE t1 USING PARQUET AS SELECT NULL AS null_col") }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`null_col`", "columnType" -> "\"VOID\"", @@ -2584,7 +2583,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("CREATE TABLE t2 STORED AS PARQUET AS SELECT null as null_col") }, - errorClass = "_LEGACY_ERROR_TEMP_3065", + condition = "_LEGACY_ERROR_TEMP_3065", parameters = Map( "clazz" -> "org.apache.hadoop.hive.ql.metadata.HiveException", "msg" -> "java.lang.UnsupportedOperationException: Unknown field type: void") @@ -2600,7 +2599,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("CREATE TABLE t1 (v VOID) USING PARQUET") }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`v`", "columnType" -> "\"VOID\"", @@ -2610,7 +2609,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("CREATE TABLE t2 (v VOID) STORED AS PARQUET") }, - errorClass = "_LEGACY_ERROR_TEMP_3065", + condition = "_LEGACY_ERROR_TEMP_3065", parameters = Map( "clazz" -> "org.apache.hadoop.hive.ql.metadata.HiveException", "msg" -> "java.lang.UnsupportedOperationException: Unknown field type: void") @@ -2818,7 +2817,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("load data inpath '/doesnotexist.csv' into table tbl") }, - errorClass = "LOAD_DATA_PATH_NOT_EXISTS", + condition = "LOAD_DATA_PATH_NOT_EXISTS", parameters = Map("path" -> "/doesnotexist.csv") ) } @@ -2860,7 +2859,7 @@ class HiveDDLSuite exception = intercept[SparkException] { sql(s"CREATE TABLE t (a $typ) USING hive") }, - errorClass = "CANNOT_RECOGNIZE_HIVE_TYPE", + condition = "CANNOT_RECOGNIZE_HIVE_TYPE", parameters = Map( "fieldType" -> toSQLType(replaced), "fieldName" -> "`a`") @@ -2878,7 +2877,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql(s"CREATE TABLE t (a $typ) USING hive") }, - errorClass = "_LEGACY_ERROR_TEMP_3065", + condition = "_LEGACY_ERROR_TEMP_3065", parameters = Map( "clazz" -> "org.apache.hadoop.hive.ql.metadata.HiveException", "msg" -> msg) @@ -2917,7 +2916,7 @@ class HiveDDLSuite |AS SELECT 1 as a, "a" as b""".stripMargin checkError( exception = intercept[ParseException](sql(sql1)), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map( "message" -> "Schema may not be specified in a Create Table As Select (CTAS) statement"), context = ExpectedContext(sql1, 0, 92)) @@ -2929,7 +2928,7 @@ class HiveDDLSuite |AS SELECT 1 as a, "a" as b""".stripMargin checkError( exception = intercept[ParseException](sql(sql2)), - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map( "message" -> "Partition column types may not be specified in Create Table As Select (CTAS)"), @@ -3020,7 +3019,7 @@ class HiveDDLSuite exception = intercept[ParseException] { sql(sql1) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map( "operation" -> ("CREATE TABLE LIKE ... USING ... 
ROW FORMAT SERDE " + "ORG.APACHE.HADOOP.HIVE.SERDE2.LAZY.LAZYSIMPLESERDE")), @@ -3036,7 +3035,7 @@ class HiveDDLSuite exception = intercept[ParseException] { sql(sql2) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map( "operation" -> ("CREATE TABLE LIKE ... USING ... ROW FORMAT SERDE " + "ORG.APACHE.HADOOP.HIVE.SERDE2.LAZY.LAZYSIMPLESERDE")), @@ -3052,7 +3051,7 @@ class HiveDDLSuite exception = intercept[ParseException] { sql(sql3) }, - errorClass = "_LEGACY_ERROR_TEMP_0047", + condition = "_LEGACY_ERROR_TEMP_0047", parameters = Map.empty, context = ExpectedContext(fragment = sql3, start = 0, stop = 153) ) @@ -3066,7 +3065,7 @@ class HiveDDLSuite exception = intercept[ParseException] { sql(sql4) }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map( "operation" -> ("CREATE TABLE LIKE ... USING ... STORED AS " + "INPUTFORMAT INFORMAT OUTPUTFORMAT OUTFORMAT ROW FORMAT " + @@ -3140,7 +3139,7 @@ class HiveDDLSuite exception = intercept[ParseException] { sql(sql1) }, - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map( "message" -> (s"ROW FORMAT SERDE is incompatible with format " + s"'${format.toLowerCase(Locale.ROOT)}', which also specifies a serde")), @@ -3179,7 +3178,7 @@ class HiveDDLSuite exception = intercept[ParseException] { sql(sql1) }, - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map( "message" -> ("ROW FORMAT DELIMITED is only compatible " + "with 'textfile', not 'parquet'")), @@ -3226,7 +3225,7 @@ class HiveDDLSuite spark.sql(s"INSERT OVERWRITE LOCAL DIRECTORY '${path.getCanonicalPath}' " + s"STORED AS $format SELECT ID, if(1=1, 1, 0), abs(id), '^-' FROM v") }.getCause.asInstanceOf[AnalysisException], - errorClass = "INVALID_COLUMN_NAME_AS_PATH", + condition = "INVALID_COLUMN_NAME_AS_PATH", parameters = Map( "datasource" -> "HiveFileFormat", "columnName" -> "`(IF((1 = 1), 1, 0))`" @@ -3254,7 +3253,7 @@ class HiveDDLSuite |FROM v """.stripMargin) }.getCause.asInstanceOf[AnalysisException], - errorClass = "INVALID_COLUMN_NAME_AS_PATH", + condition = "INVALID_COLUMN_NAME_AS_PATH", parameters = Map("datasource" -> "HiveFileFormat", "columnName" -> "`IF(ID=1,ID,0)`") ) } @@ -3276,7 +3275,7 @@ class HiveDDLSuite s"'org.apache.hadoop.hive.ql.udf.UDFUUID' USING JAR '$jar'") } checkError(e, - errorClass = "ROUTINE_ALREADY_EXISTS", + condition = "ROUTINE_ALREADY_EXISTS", parameters = Map("routineName" -> "`f1`", "newRoutineType" -> "routine", "existingRoutineType" -> "routine")) @@ -3305,7 +3304,7 @@ class HiveDDLSuite exception = intercept[SparkUnsupportedOperationException] { sql(sqlCmd) }, - errorClass = "UNSUPPORTED_FEATURE.HIVE_WITH_ANSI_INTERVALS", + condition = "UNSUPPORTED_FEATURE.HIVE_WITH_ANSI_INTERVALS", parameters = Map("tableName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$tbl`") ) } @@ -3356,7 +3355,7 @@ class HiveDDLSuite exception = intercept[AnalysisException] { sql("CREATE TABLE tab (c1 int) PARTITIONED BY (c1) STORED AS PARQUET") }, - errorClass = "ALL_PARTITION_COLUMNS_NOT_ALLOWED", + condition = "ALL_PARTITION_COLUMNS_NOT_ALLOWED", parameters = Map.empty ) } @@ -3369,7 +3368,7 @@ class HiveDDLSuite sql(s"DELETE FROM $tbl WHERE c1 = 1") } checkError(e, - errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", parameters = Map( "tableName" -> s"`$SESSION_CATALOG_NAME`.`default`.`$tbl`", "operation" -> "DELETE") 
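For context on the call-site pattern the HiveDDLSuite hunks above converge on: only the named parameter changes from `errorClass` to `condition`; the expected error name and its parameters map stay exactly as before. The sketch below is illustrative only and is not part of this patch. The suite name `ConditionRenameExampleSuite` and table name `tbl_example` are invented, and it assumes Spark's Hive test classpath (QueryTest, SQLTestUtils, TestHiveSingleton), where `checkError`, `sql` and `withTable` are available as in the suites touched here.

import org.apache.spark.sql.{AnalysisException, QueryTest}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils

// Hypothetical suite showing the post-rename assertion idiom used throughout this file.
class ConditionRenameExampleSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
  test("duplicate partition column is reported via `condition`") {
    withTable("tbl_example") {
      checkError(
        exception = intercept[AnalysisException] {
          sql("CREATE TABLE tbl_example(a INT) PARTITIONED BY (a STRING)")
        },
        // Formerly written as `errorClass = ...`; the assertion itself is unchanged.
        condition = "COLUMN_ALREADY_EXISTS",
        parameters = Map("columnName" -> "`a`"))
    }
  }
}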
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index e5e180e7c135c..42fc50e5b163b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -78,7 +78,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd exception = intercept[ParseException] { body }, - errorClass = "INVALID_STATEMENT_OR_CLAUSE", + condition = "INVALID_STATEMENT_OR_CLAUSE", parameters = Map("operation" -> operation), context = expectedContext) } @@ -683,7 +683,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd exception = intercept[AnalysisException] { sql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"CASE WHEN (key > 2) THEN 3 WHEN 1 THEN 2 ELSE 0 END\"", "paramIndex" -> "second", @@ -819,7 +819,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' |WITH serdeproperties('s1'='9')""".stripMargin) }, - errorClass = "_LEGACY_ERROR_TEMP_3065", + condition = "_LEGACY_ERROR_TEMP_3065", parameters = Map( "clazz" -> "org.apache.hadoop.hive.ql.metadata.HiveException", "msg" -> "at least one column must be specified for the table")) @@ -1251,7 +1251,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd """INSERT INTO TABLE dp_test PARTITION(dp) |SELECT key, value, key % 5 FROM src""".stripMargin) }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", + condition = "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`dp_test`", "tableColumns" -> "`key`, `value`, `dp`, `sp`", @@ -1265,7 +1265,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd """INSERT INTO TABLE dp_test PARTITION(dp, sp = 1) |SELECT key, value, key % 5 FROM src""".stripMargin) }, - errorClass = "_LEGACY_ERROR_TEMP_3079", + condition = "_LEGACY_ERROR_TEMP_3079", parameters = Map.empty) } } @@ -1368,7 +1368,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd exception = intercept[AnalysisException] { sql("select * from test_b") }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`test_b`"), context = ExpectedContext( fragment = "test_b", @@ -1382,7 +1382,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd exception = intercept[AnalysisException] { s2.sql("select * from test_a") }, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> "`test_a`"), context = ExpectedContext( fragment = "test_a", @@ -1408,7 +1408,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd exception = intercept[AnalysisException] { sql("USE not_existing_db") }, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> "`spark_catalog`.`not_existing_db`") ) } @@ -1420,7 +1420,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd exception = 
intercept[AnalysisException] { range(1).selectExpr("not_a_udf()") }, - errorClass = "UNRESOLVED_ROUTINE", + condition = "UNRESOLVED_ROUTINE", sqlState = None, parameters = Map( "routineName" -> "`not_a_udf`", @@ -1437,7 +1437,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd exception = intercept[AnalysisException] { range(1).selectExpr("not_a_udf()") }, - errorClass = "UNRESOLVED_ROUTINE", + condition = "UNRESOLVED_ROUTINE", sqlState = None, parameters = Map( "routineName" -> "`not_a_udf`", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index d7d859f57e5b6..df6ef57a581d0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -48,7 +48,7 @@ class HiveResolutionSuite extends HiveComparisonTest { exception = intercept[AnalysisException] { sql("SELECT a[0].b from nested").queryExecution.analyzed }, - errorClass = "AMBIGUOUS_REFERENCE_TO_FIELDS", + condition = "AMBIGUOUS_REFERENCE_TO_FIELDS", sqlState = "42000", parameters = Map("field" -> "`b`", "count" -> "2") ) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala index 9d86c72f86afd..8ec3dd6dffa14 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala @@ -87,7 +87,7 @@ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton { } checkError( exception = e, - errorClass = "INVALID_TEMP_OBJ_REFERENCE", + condition = "INVALID_TEMP_OBJ_REFERENCE", parameters = Map( "obj" -> "VIEW", "objName" -> s"`$SESSION_CATALOG_NAME`.`default`.`view1`", @@ -213,7 +213,7 @@ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton { exception = intercept[AnalysisException] { sql("SHOW CREATE TABLE v1") }, - errorClass = "UNSUPPORTED_SHOW_CREATE_TABLE.WITH_UNSUPPORTED_FEATURE", + condition = "UNSUPPORTED_SHOW_CREATE_TABLE.WITH_UNSUPPORTED_FEATURE", sqlState = "0A000", parameters = Map( "tableName" -> s"`$SESSION_CATALOG_NAME`.`default`.`v1`", @@ -224,7 +224,7 @@ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton { exception = intercept[AnalysisException] { sql("SHOW CREATE TABLE v1 AS SERDE") }, - errorClass = "UNSUPPORTED_SHOW_CREATE_TABLE.WITH_UNSUPPORTED_FEATURE", + condition = "UNSUPPORTED_SHOW_CREATE_TABLE.WITH_UNSUPPORTED_FEATURE", sqlState = "0A000", parameters = Map( "tableName" -> s"`$SESSION_CATALOG_NAME`.`default`.`v1`", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeReadWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeReadWriteSuite.scala index aafc4764d2465..1922144a92efa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeReadWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeReadWriteSuite.scala @@ -44,7 +44,7 @@ class HiveSerDeReadWriteSuite extends QueryTest with SQLTestUtils with TestHiveS super.beforeAll() originalConvertMetastoreParquet = spark.conf.get(CONVERT_METASTORE_PARQUET.key) originalConvertMetastoreORC = spark.conf.get(CONVERT_METASTORE_ORC.key) - originalORCImplementation = 
spark.conf.get(ORC_IMPLEMENTATION) + originalORCImplementation = spark.conf.get(ORC_IMPLEMENTATION.key) spark.conf.set(CONVERT_METASTORE_PARQUET.key, "false") spark.conf.set(CONVERT_METASTORE_ORC.key, "false") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index 4aadd710b42a7..9bf84687c8f51 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -151,7 +151,7 @@ class HiveUDAFSuite extends QueryTest exception = intercept[AnalysisException] { sql("SELECT testUDAFPercentile(x, rand()) from view1 group by y") }, - errorClass = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION", + condition = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION", parameters = Map("sqlExpr" -> "\"testUDAFPercentile( x, rand())\""), context = ExpectedContext( fragment = "rand()", @@ -181,7 +181,7 @@ class HiveUDAFSuite extends QueryTest exception = intercept[AnalysisException] { sql(s"SELECT $functionName(100)") }, - errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", parameters = Map( "functionName" -> toSQLId("longProductSum"), "expectedNum" -> "2", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 2e88b13f0963d..6604fe2a9d61e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -591,7 +591,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { exception = intercept[AnalysisException] { sql("SELECT dAtABaSe1.unknownFunc(1)") }, - errorClass = "UNRESOLVED_ROUTINE", + condition = "UNRESOLVED_ROUTINE", parameters = Map( "routineName" -> "`dAtABaSe1`.`unknownFunc`", "searchPath" -> @@ -790,7 +790,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { checkError( exception = intercept[SparkException](df.collect()), - errorClass = "FAILED_EXECUTE_UDF", + condition = "FAILED_EXECUTE_UDF", parameters = Map( "functionName" -> "`org`.`apache`.`spark`.`sql`.`hive`.`execution`.`SimpleUDFAssertTrue`", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala index 98801e0b0273a..0c54381551bf8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala @@ -107,7 +107,7 @@ class Hive_2_1_DDLSuite extends SparkFunSuite with TestHiveSingleton { "CREATE TABLE t1 (c1 string) USING parquet", StructType(Array(StructField("c2", IntegerType)))) }, - errorClass = "_LEGACY_ERROR_TEMP_3065", + condition = "_LEGACY_ERROR_TEMP_3065", parameters = Map( "clazz" -> "org.apache.hadoop.hive.ql.metadata.HiveException", "msg" -> ("Unable to alter table. 
" + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 05b73e31d1156..14051034a588e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -79,13 +79,13 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi test("query global temp view") { val df = Seq(1).toDF("i1") df.createGlobalTempView("tbl1") - val global_temp_db = spark.conf.get(GLOBAL_TEMP_DATABASE) + val global_temp_db = spark.conf.get(GLOBAL_TEMP_DATABASE.key) checkAnswer(spark.sql(s"select * from ${global_temp_db}.tbl1"), Row(1)) spark.sql(s"drop view ${global_temp_db}.tbl1") } test("non-existent global temp view") { - val global_temp_db = spark.conf.get(GLOBAL_TEMP_DATABASE) + val global_temp_db = spark.conf.get(GLOBAL_TEMP_DATABASE.key) val e = intercept[AnalysisException] { spark.sql(s"select * from ${global_temp_db}.nonexistentview") } @@ -221,7 +221,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi val sqlText = "describe functioN abcadf" checkError( exception = intercept[AnalysisException](sql(sqlText)), - errorClass = "UNRESOLVED_ROUTINE", + condition = "UNRESOLVED_ROUTINE", parameters = Map( "routineName" -> "`abcadf`", "searchPath" -> "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]"), @@ -246,7 +246,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi checkKeywordsExist(sql("describe function `between`"), "Function: between", - "Usage: input [NOT] BETWEEN lower AND upper - " + + "input [NOT] between lower AND upper - " + "evaluate if `input` is [not] in between `lower` and `upper`") checkKeywordsExist(sql("describe function `case`"), @@ -1356,7 +1356,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi exception = intercept[AnalysisException] { sql(s"select id from parquet.`invalid_path`") }, - errorClass = "PATH_NOT_FOUND", + condition = "PATH_NOT_FOUND", parameters = Map("path" -> "file.*invalid_path"), matchPVals = true ) @@ -1413,7 +1413,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi exception = intercept[AnalysisException] { sql(s"select id from hive.`${f.getCanonicalPath}`") }, - errorClass = "UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY", + condition = "UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY", parameters = Map("dataSourceType" -> "hive"), context = ExpectedContext(s"hive.`${f.getCanonicalPath}`", 15, 21 + f.getCanonicalPath.length) @@ -1424,7 +1424,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi exception = intercept[AnalysisException] { sql(s"select id from HIVE.`${f.getCanonicalPath}`") }, - errorClass = "UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY", + condition = "UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY", parameters = Map("dataSourceType" -> "HIVE"), context = ExpectedContext(s"HIVE.`${f.getCanonicalPath}`", 15, 21 + f.getCanonicalPath.length) @@ -1782,7 +1782,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi |AS SELECT 1 AS a, 2 AS b """.stripMargin) }, - errorClass = "_LEGACY_ERROR_TEMP_0035", + condition = "_LEGACY_ERROR_TEMP_0035", parameters = Map("message" -> "Column ordering must be ASC, was 'DESC'"), context = ExpectedContext( fragment = "CLUSTERED BY (a) SORTED BY (b DESC) INTO 2 BUCKETS", @@ 
-2638,7 +2638,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi sql("CREATE TABLE t (a STRING)") checkError( exception = intercept[AnalysisException](sql("INSERT INTO t SELECT a*2 FROM t where b=1")), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map("objectName" -> "`b`", "proposal" -> "`a`"), context = ExpectedContext( @@ -2648,7 +2648,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi checkError( exception = intercept[AnalysisException]( sql("INSERT INTO t SELECT cast(a as short) FROM t where b=1")), - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = None, parameters = Map("objectName" -> "`b`", "proposal" -> "`a`"), context = ExpectedContext( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala index 2eff462faa8dc..ce1b41ecc6dd7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala @@ -234,7 +234,7 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi |GROUP BY key """.stripMargin) }, - errorClass = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION", + condition = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION", parameters = Map("sqlExpr" -> "\"mydoublesum(((value + (1.5 * key)) + rand()))\""), context = ExpectedContext( fragment = "value + 1.5 * key + rand()", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/AlterNamespaceSetLocationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/AlterNamespaceSetLocationSuite.scala index 1dbe405b217e5..232916b6e05b2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/AlterNamespaceSetLocationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/AlterNamespaceSetLocationSuite.scala @@ -36,7 +36,7 @@ class AlterNamespaceSetLocationSuite extends v1.AlterNamespaceSetLocationSuiteBa exception = intercept[AnalysisException] { sql(s"ALTER DATABASE $ns SET LOCATION 'loc'") }, - errorClass = "_LEGACY_ERROR_TEMP_1219", + condition = "_LEGACY_ERROR_TEMP_1219", parameters = Map.empty ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/AlterTableAddColumnsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/AlterTableAddColumnsSuite.scala index 3ae2ff562d102..521ad759c302d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/AlterTableAddColumnsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/AlterTableAddColumnsSuite.scala @@ -35,7 +35,7 @@ class AlterTableAddColumnsSuite exception = intercept[SparkUnsupportedOperationException] { sql(s"ALTER TABLE $tbl ADD COLUMNS (ym INTERVAL YEAR)") }, - errorClass = "UNSUPPORTED_FEATURE.HIVE_WITH_ANSI_INTERVALS", + condition = "UNSUPPORTED_FEATURE.HIVE_WITH_ANSI_INTERVALS", parameters = Map("tableName" -> toSQLId(tbl)) ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/CreateNamespaceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/CreateNamespaceSuite.scala index 12e41a569b346..cc54469a52f3f 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/CreateNamespaceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/CreateNamespaceSuite.scala @@ -33,7 +33,7 @@ class CreateNamespaceSuite extends v1.CreateNamespaceSuiteBase with CommandSuite exception = intercept[AnalysisException] { sql(s"CREATE NAMESPACE $catalog.$namespace") }, - errorClass = "REQUIRES_SINGLE_PART_NAMESPACE", + condition = "REQUIRES_SINGLE_PART_NAMESPACE", parameters = Map( "sessionCatalog" -> catalog, "namespace" -> "`ns1`.`ns2`" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowCreateTableSuite.scala index 4c6252128094f..8e654d28cd033 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowCreateTableSuite.scala @@ -365,7 +365,7 @@ class ShowCreateTableSuite extends v1.ShowCreateTableSuiteBase with CommandSuite exception = intercept[AnalysisException] { checkCreateSparkTableAsHive("t1") }, - errorClass = "UNSUPPORTED_SHOW_CREATE_TABLE.WITH_UNSUPPORTED_SERDE_CONFIGURATION", + condition = "UNSUPPORTED_SHOW_CREATE_TABLE.WITH_UNSUPPORTED_SERDE_CONFIGURATION", sqlState = "0A000", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`t1`", @@ -438,7 +438,7 @@ class ShowCreateTableSuite extends v1.ShowCreateTableSuiteBase with CommandSuite exception = intercept[AnalysisException] { sql("SHOW CREATE TABLE t1") }, - errorClass = "UNSUPPORTED_SHOW_CREATE_TABLE.ON_TRANSACTIONAL_HIVE_TABLE", + condition = "UNSUPPORTED_SHOW_CREATE_TABLE.ON_TRANSACTIONAL_HIVE_TABLE", sqlState = "0A000", parameters = Map("tableName" -> "`spark_catalog`.`default`.`t1`") ) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala index 64a7731a3bf84..4c6218c6366c8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala @@ -58,7 +58,7 @@ class HiveOrcQuerySuite extends OrcQueryTest with TestHiveSingleton { exception = intercept[AnalysisException] { spark.read.orc(path) }, - errorClass = "UNABLE_TO_INFER_SCHEMA", + condition = "UNABLE_TO_INFER_SCHEMA", parameters = Map("format" -> "ORC") ) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index bac48f6c0c018..c1084dd4ee7ff 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -117,7 +117,7 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { exception = intercept[AnalysisException] { sql("select interval 1 days").write.mode("overwrite").orc(orcDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`INTERVAL '1' DAY`", "columnType" -> "\"INTERVAL DAY\"", @@ -128,7 +128,7 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { exception = intercept[AnalysisException] { sql("select null").write.mode("overwrite").orc(orcDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + 
condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`NULL`", "columnType" -> "\"VOID\"", @@ -140,7 +140,7 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { spark.udf.register("testType", () => new IntervalData()) sql("select testType()").write.mode("overwrite").orc(orcDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`testType()`", "columnType" -> "UDT(\"INTERVAL\")", @@ -154,7 +154,7 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { spark.range(1).write.mode("overwrite").orc(orcDir) spark.read.schema(schema).orc(orcDir).collect() }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`a`", "columnType" -> "\"INTERVAL\"", @@ -167,7 +167,7 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { spark.range(1).write.mode("overwrite").orc(orcDir) spark.read.schema(schema).orc(orcDir).collect() }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`a`", "columnType" -> "UDT(\"INTERVAL\")", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 86401bf923927..56f835b53a75d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -246,7 +246,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes testDF.write.format(dataSourceName) .mode(SaveMode.ErrorIfExists).save(file.getCanonicalPath) }, - errorClass = "PATH_ALREADY_EXISTS", + condition = "PATH_ALREADY_EXISTS", parameters = Map("outputPath" -> "file:.*"), matchPVals = true ) @@ -354,7 +354,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .partitionBy("p1", "p2") .save(file.getCanonicalPath) }, - errorClass = "PATH_ALREADY_EXISTS", + condition = "PATH_ALREADY_EXISTS", parameters = Map("outputPath" -> "file:.*"), matchPVals = true )