diff --git a/README.md b/README.md index ca021cd..1f78728 100644 --- a/README.md +++ b/README.md @@ -2,16 +2,27 @@ A Datasource on top of Spark Datasource V1 APIs, that provides Spark support for [Hive ACID transactions](https://cwiki.apache.org/confluence/display/Hive/Hive+Transactions). -This datasource provides the capability to work with Hive ACID V2 tables, both Full ACID tables as well as Insert-Only tables. Currently, it supports reading from these ACID tables only, and ability to write will be added in the near future. +This datasource provides the capability to work with Hive ACID V2 tables, both Full ACID tables as well as Insert-Only tables. + +Functionality availability matrix: + +Functionality | Full ACID table | Insert Only Table | +------------- | --------------- | ----------------- | +READ | >= v0.4.0 | >= v0.4.0 | +INSERT INTO / OVERWRITE | >= v0.4.3 | >= v0.4.5 | +CTAS | >= v0.4.3 | >= v0.4.5 | +UPDATE | >= v0.4.5 | Not Supported | +DELETE | >= v0.4.5 | Not Supported | +MERGE | Not Supported | Not Supported | + +*Note: For Insert-Only tables, the metastore compatibility check needs to be disabled to support write operations.* ## Quick Start These are the pre-requisites to using this library: -1. You already have Hive ACID tables (ACID V2) and need to read it from Spark (as currently write is not _NOT_ supported). -2. You have Hive Metastore DB with version 3.0.0 or higher. Please refer to [Hive Metastore](https://cwiki.apache.org/confluence/display/Hive/Design#Design-MetastoreArchitecture) for details. -3. You have a Hive Metastore Server running with version 3.0.0 or higher, as Hive ACID needs a standalone Hive Metastore Server to operate. Please refer to [Hive Configuration](https://cwiki.apache.org/confluence/display/Hive/Hive+Transactions#HiveTransactions-Configuration) for configuration options. -4. You are using the above Hive Metastore Server with your Spark for its metastore communications. +1. You have a Hive Metastore DB with version 3.1.2 or higher. Please refer to [Hive Metastore](https://cwiki.apache.org/confluence/display/Hive/Design#Design-MetastoreArchitecture) for details. +2. You have a Hive Metastore Server running with version 3.1.1 or higher, as Hive ACID needs a standalone Hive Metastore Server to operate. Please refer to [Hive Configuration](https://cwiki.apache.org/confluence/display/Hive/Hive+Transactions#HiveTransactions-Configuration) for configuration options. ### Config @@ -32,47 +43,112 @@ Change configuration in `$SPARK_HOME/conf/hive-site.xml` to point to already con There are a few ways to use the library while running spark-shell -1. Use the published package + spark-shell --packages qubole:spark-acid:0.4.2-s_2.11 - spark-shell --packages qubole:spark-acid:0.4.0-s_2.11 - -2. If you built the jar yourself, copy the `spark-acid-assembly-0.4.0.jar` jar into `$SPARK_HOME/assembly/target/scala.2_11/jars` and run +2. If you built the jar yourself, copy the `spark-acid-assembly-0.4.2.jar` jar into `$SPARK_HOME/assembly/target/scala-2.11/jars` and run spark-shell - + #### Scala/Python -To read the acid table from Scala / pySpark, the table can be directly accessed using this datasource. -Note the short name of this datasource is `HiveAcid` +To operate on a Hive ACID table from Scala / pySpark, the table can be directly accessed using this datasource. Note that the short name of this datasource is `HiveAcid`. Hive ACID tables are tables in the Hive Metastore, so any read or write operation needs `format("HiveAcid").option("table", "<fully qualified table name>")`.
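+For example, assuming a transactional table `default.acidtbl` already exists in the metastore (the table name here is purely illustrative), a minimal read and append using the single-`option` form might look like the sketch below; the snippets that follow show the equivalent read using the `options(Map(...))` form.
+
+    // Read the ACID table through this datasource (short name "HiveAcid")
+    val df = spark.read.format("HiveAcid").option("table", "default.acidtbl").load()
+
+    // Append the same rows back into the table transactionally
+    // (write support requires the versions listed in the matrix above)
+    df.write.format("HiveAcid").option("table", "default.acidtbl").mode("append").save()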
_Direct reads and writes at the file level are not supported._ scala> val df = spark.read.format("HiveAcid").options(Map("table" -> "default.acidtbl")).load() scala> df.collect() #### SQL -To read an existing Hive acid table through pure SQL, you need to create a dummy table that acts as a symlink to the -original acid table. This symlink is required to instruct Spark to use this datasource against an existing table. -To create the symlink table +To read an existing Hive acid table through pure SQL, there are two ways: + +1. Create a dummy table that acts as a symlink to the original acid table. This symlink is required to instruct Spark to use this datasource against an existing table. - scala> spark.sql("create table symlinkacidtable using HiveAcid options ('table' 'default.acidtbl')") + To create the symlink table: -_NB: This will produce a warning indicating that Hive does not understand this format_ + spark.sql("create table symlinkacidtable using HiveAcid options ('table' 'default.acidtbl')") - WARN hive.HiveExternalCatalog: Couldn’t find corresponding Hive SerDe for data source provider com.qubole.spark.datasources.hiveacid.HiveAcidDataSource. Persisting data source table `default`.`sparkacidtbl` into Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. + spark.sql("select * from symlinkacidtable") + + + _NB: This will produce a warning indicating that Hive does not understand this format_ + + WARN hive.HiveExternalCatalog: Couldn’t find corresponding Hive SerDe for data source provider com.qubole.spark.hiveacid.datasource.HiveAcidDataSource. Persisting data source table `default`.`sparkacidtbl` into Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. _Please ignore it, as this is a sym table for Spark to operate with and no underlying storage._ -To read the table data: +2. Use the SparkSession extensions framework to add a new analyzer rule (HiveAcidAutoConvert) to the Spark analyzer. This rule automatically converts a _HiveTableRelation_ representing an ACID table into a _LogicalRelation_ backed by HiveAcidRelation. - scala> var df = spark.sql("select * from symlinkacidtable") - scala> df.collect() + To use this, initialize the SparkSession with the extension builder as shown below: + val spark = SparkSession.builder() + .appName("Hive-acid-test") + .config("spark.sql.extensions", "com.qubole.spark.hiveacid.HiveAcidAutoConvertExtension") + .enableHiveSupport() + .getOrCreate() -## Latest Binaries + spark.sql("select * from default.acidtbl") + +#### Example + +##### Create Hive ACID Table + +Drop the existing table + + spark.sql("Drop table if exists aciddemo.t_scala_simple") + +Create the table + + spark.sql("CREATE TABLE aciddemo.t_scala_simple (status BOOLEAN, tweet ARRAY<STRING>, rank DOUBLE, username STRING) STORED AS ORC TBLPROPERTIES('TRANSACTIONAL' = 'true')") + +Check if it is transactional + + spark.sql("DESCRIBE extended aciddemo.t_scala_simple").show() + + +##### Scala + +Read an existing table and insert into the ACID table + + val df = spark.read.format("HiveAcid").options(Map("table" -> "aciddemo.acidtbl")).load() + df.write.format("HiveAcid").option("table", "aciddemo.t_scala_simple").mode("append").save() + +Read an existing table and insert overwrite the ACID table -ACID datasource is published spark-packages.org.
The latest version of the binary is `0.4.0` + val df = spark.read.format("HiveAcid").options(Map("table" -> "aciddemo.acidtbl")).load() + df.write.format("HiveAcid").option("table", "aciddemo.t_scala_simple").mode("overwrite").save() +_Note: Users cannot operate directly on file-level data, as a table is required for transactional reads and writes. +`df.write.format("HiveAcid").mode("overwrite").save("s3n://aciddemo/api/warehouse/aciddemo.db/random")` won't work_ + +Read the ACID table + + val df = spark.read.format("HiveAcid").options(Map("table" -> "aciddemo.t_scala_simple")).load() + df.select("status", "rank").filter($"rank" > "20").show() + +##### SQL + +Insert into the table as select + + spark.sql("INSERT INTO aciddemo.t_sql_simple select * from aciddemo.acidtbl") + +Insert overwrite the table as select + + spark.sql("INSERT OVERWRITE TABLE aciddemo.t_sql_simple select * from aciddemo.acidtbl") + +Insert values into the table + + spark.sql("INSERT INTO aciddemo.t_sql_simple VALUES(false, array('test'), 11.2, 'qubole')") + +Read + + spark.sql("SELECT status, rank from aciddemo.t_sql_simple where rank > 20") + + +## Latest Binaries + +The ACID datasource is published at spark-packages.org. The latest version of the binary is `0.4.2` + ## Version Compatibility ### Compatibility with Apache Spark Versions @@ -90,21 +166,18 @@ _NB: Hive ACID V2 is supported in Hive 3.0.0 onwards and for that hive Metastore ## Developer resources ### Build -1. First, build the dependencies and publish it to local. The *shaded-dependencies* sub-project is an sbt project to create the shaded hive metastore and hive exec jars combined into a fat jar `spark-acid-shaded-dependencies`. This is required due to our dependency on Hive 3 for Hive ACID, and Spark currently only supports Hive 1.2 - -To compile and publish shaded dependencies jar: +1. First, build the dependencies and publish them locally. The *shaded-dependencies* sub-project is an sbt project to create the shaded hive metastore and hive exec jars combined into a fat jar `spark-acid-shaded-dependencies`. This is required due to our dependency on Hive 3 for Hive ACID, and Spark currently only supports Hive 1.2. To compile and publish the shaded dependencies jar: cd shaded-dependencies sbt clean publishLocal - + 2. Next, build the main project: - cd ../ - sbt assembly + sbt assembly -This will create the `spark-acid-assembly-0.4.0.jar` which can be now used in your application. +This will create the `spark-acid-assembly-0.4.2.jar` which can now be used in your application. -### Test +### Test Tests are run against a standalone docker setup. Please refer to [Docker setup] (docker/README.md) to build and start a container. _NB: Container run HMS server, HS2 Server and HDFS and listens on port 10000,10001 and 9000 respectively. So stop if you are running HMS or HDFS on same port on host machine._ @@ -129,32 +202,20 @@ Read more about [sbt release](https://github.com/sbt/sbt-release) ### Design Constraints -Hive ACID works with locks, where every client that is operating on ACID tables is expected to acquire locks for the duration of reads and writes. This datasource however does not acquire read locks. When it needs to read data, it talks to the HiveMetaStore Server to get the list of transactions that have been committed, and using that, the list of files it should read from the filesystem. But it does not lock the table or partition for the duration of the read.
- -Because it does not acquire read locks, there is a chance that the data being read could get deleted by Hive's ACID management(perhaps because the data was ready to be cleaned up due to compaction). To avoid this scenario which can read to query failures, we recommend that you disable automatic compaction and cleanup in Hive on the tables that you are going to be reading using this datasource, and recommend that the compaction and cleanup be done when you know that no users are reading those tables. Ideally, we would have wanted to just disable automatic cleanup and let the compaction happen, but there is no way in Hive today to just disable cleanup and it is tied to compaction, so we recommend to disable compaction. - -You have a few options available to you to disable automatic compaction: - -1. Disable automatic compaction globally, i.e. for all ACID tables: To do this, we recommend you set the following compaction thresholds on the Hive Metastore Server to a very high number(like 1000000 below) so that compaction never gets initiated automatically and can only be initiated manually. - - hive.compactor.delta.pct.threshold=1000000 - hive.compactor.delta.num.threshold=1000000 - -2. Disable automatic compaction for selected ACID tables: To do this, you can set a table property using the ALTER TABLE command: +1. When this datasource needs to read data, it talks to the HiveMetaStore Server to get the list of transactions that have been committed and, using that, the list of files it should read from the filesystem (_via S3 listing_). Since the snapshot of the file list is created by listing, a consistency layer such as S3Guard should be used on cloud object stores like S3 to avoid reading an inconsistent copy of the data. - ALTER TABLE <> SET TBLPROPERTIES ("NO_AUTO_COMPACTION"="true") +2. This snapshot of the file list is created at the RDD level. Because snapshots are per RDD, even a single SQL statement that uses the same table twice may operate on two different snapshots, for example: -This will disable automatic compaction on a particular table, and you can use this approach if you have a limited set of ACID tables that you intend to access using this datasource. + spark.sql("select * from a t1 join a t2") -Once you have disabled automatic compaction either globally or on a particular set of tables, you can chose to run compaction manually at a desired time when you know there are no readers reading these acid tables, using an ALTER TABLE command: +3. The files in the snapshot need to be protected for as long as the RDD is in use. By design, concurrent reads and writes on Hive ACID tables work with the help of locks, where every client (across multiple engines) operating on ACID tables is expected to acquire locks for the duration of its reads and writes. Since the lifetime of an RDD can be very long, this datasource _DOES NOT_ acquire locks (which would block other operations such as inserts) but instead protects the snapshot by making sure the files in it are not deleted while in use. For the current datasource, `Automatic Compaction` should be disabled on any table that Spark is operating on; this makes sure that the cleaner does not clean any file.
To disable automatic compaction on a table: - ALTER TABLE table_name [PARTITION (partition_key = 'partition_value' [, ...])] COMPACT 'compaction_type'[AND WAIT] [WITH OVERWRITE TBLPROPERTIES ("property"="value" [, ...])]; + ALTER TABLE <> SET TBLPROPERTIES ("NO_AUTO_COMPACTION"="true") -compaction_type are either `MAJOR` or `MINOR` + When the table is not in use, the cleaner can be enabled again and all the files that need cleaning will get queued up for the cleaner. Disabling compaction does have performance implications for reads/writes, as a lot of delta files may need to be merged when performing a read. -More details on the above commands and their variations available [here](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL). +4. Note that even though reads are protected, admin operations like `TRUNCATE`, `ALTER TABLE DROP COLUMN` and `DROP` have no protection, as they clean files without intervention from the cleaner. These operations should be performed when Spark is not using the table. -We are looking into removing this restriction, and hope to be able to fix this in the near future. ## Contributing @@ -162,4 +223,4 @@ We use [Github Issues](https://github.com/qubole/spark-acid/issues) to track iss ## Reporting bugs or feature requests -Please use the github issues for the spark-acid project to report issues or raise feature requests. +Please use the github issues for the acid-ds project to report issues or raise feature requests. diff --git a/build.sbt b/build.sbt index 37cac67..fc3e42c 100644 --- a/build.sbt +++ b/build.sbt @@ -81,7 +81,9 @@ excludeDependencies ++= Seq ( // orc "org.apache.orc" % "orc-core", - "org.apache.orc" % "orc-mapreduce" + "org.apache.orc" % "orc-mapreduce", + + "org.slf4j" % "slf4j-api" ) // do not run test at assembly diff --git a/docker/beeline b/docker/beeline new file mode 100755 index 0000000..26072a8 --- /dev/null +++ b/docker/beeline @@ -0,0 +1,8 @@ +#!/bin/bash + +name="spark-hiveacid-test-container" + +docker exec -it $name bin/bash -c "\ + .
~/.bashrc; \ + export HADOOP_HOME=/hadoop; \ + hive/bin/beeline -u jdbc:hive2://0.0.0.0:10001/default root root" diff --git a/docker/files/hive-site.xml b/docker/files/hive-site.xml index 4d83ed1..dafe0a0 100644 --- a/docker/files/hive-site.xml +++ b/docker/files/hive-site.xml @@ -91,4 +91,10 @@ hive.auto.convert.join false + + + hive.stats.autogather + false + + diff --git a/docker/spark-shell b/docker/spark-shell new file mode 100755 index 0000000..6ed7222 --- /dev/null +++ b/docker/spark-shell @@ -0,0 +1,66 @@ +#!/bin/bash +if [ -z ${2} ] +then + echo "Specify the spark-acid jar location" + echo "spark-shell ~/codeline/TOT ~/codeline/TOT/acid-ds/target/scala-2.11/spark-acid-qds-assembly-0.4.3.jar" + exit +fi +if [ -z ${1} ] +then + echo "Specify and spark code base directory" + echo "spark-shell ~/codeline/TOT ~/codeline/TOT/acid-ds/target/scala-2.11/spark-acid-qds-assembly-0.4.3.jar" + exit +fi + +shellenv() { + export QENV_LOCAL_CODELINE="${1}" + export QENV_LOCAL_CONF="${QENV_LOCAL}/conf" + export HADOOP_SRC="${QENV_LOCAL_CODELINE}/hadoop2" + export SPARK_SRC="${QENV_LOCAL_CODELINE}/spark" + export HUSTLER_SRC="${QENV_LOCAL_CODELINE}/hustler" + export HIVE_SRC="${QENV_LOCAL_CODELINE}/hive" + export ZEPPELIN_SRC="${QENV_LOCAL_CODELINE}/zeppelin" +} + +hsnapshot() { + HADOOP_SNAPSHOT=`ls ${HADOOP_SRC}/hadoop-dist/target/hadoop* | grep SNAPSHOT: | cut -d':' -f1` +} + +hivesnapshot() { + loc=`ls ${HIVE_SRC}/packaging/target/apache-hive* |grep bin |grep -v ':'` + HIVE_SNAPSHOT=${HIVE_SRC}/packaging/target/${loc}/${loc}/ +} + +run_spark_shelllocal() { + + # Setup writest into spark-env file. Run spark-shell after it. + echo "Update Spark Conf based on Hadoop Build Version --> ${SPARK_SRC}/conf/spark-env.sh" + hsnapshot + hivesnapshot + + str="export SPARK_YARN_USER_ENV=CLASSPATH=${QENV_LOCAL_CONF}/" + echo ${str} > ${SPARK_SRC}/conf/spark-env.sh + + if [ -n "${HADOOP_SNAPSHOT}" ] + then + + str="export SPARK_DIST_CLASSPATH=${QENV_LOCAL_CONF}/:${HADOOP_SNAPSHOT}/share/hadoop/common/lib/*:${HADOOP_SNAPSHOT}/share/hadoop/common/*:${HADOOP_SNAPSHOT}/share/hadoop/hdfs:${HADOOP_SNAPSHOT}/share/hadoop/hdfs/lib/*:${HADOOP_SNAPSHOT}/share/hadoop/hdfs/*:${HADOOP_SNAPSHOT}/share/hadoop/yarn/lib/*:${HADOOP_SNAPSHOT}/share/hadoop/yarn/*:${HADOOP_SNAPSHOT}/share/hadoop/mapreduce/*:/share/hadoop/tools:${HADOOP_SNAPSHOT}/share/hadoop/tools/lib/*:${HADOOP_SNAPSHOT}/share/hadoop/tools/*:/share/hadoop/qubole:${HADOOP_SNAPSHOT}/share/hadoop/qubole/*" + echo ${str} >> ${SPARK_SRC}/conf/spark-env.sh + fi + + if [ -n "${HIVE_SNAPSHOT}" ] + then + str="export SPARK_DIST_CLASSPATH=\${SPARK_DIST_CLASSPATH}:${HIVE_SNAPSHOT}/lib/*" + echo ${str} >> ${SPARK_SRC}/conf/spark-env.sh + fi + + str="export HADOOP_CONF_DIR=${QENV_LOCAL_CONF}/" + echo ${str} >> ${SPARK_SRC}/conf/spark-env.sh + + $SPARK_SRC/bin/spark-shell $@ +} + + +shellenv ${1} +shift +run_spark_shelllocal --jars $@ --conf spark.sql.extensions=com.qubole.spark.datasources.hiveacid.HiveAcidAutoConvertExtension --conf spark.hadoop.hive.metastore.uris=thrift://localhost:10000 --conf spark.sql.catalogImplementation=hive diff --git a/shaded-dependencies/build.sbt b/shaded-dependencies/build.sbt index c6ad1f4..6110854 100644 --- a/shaded-dependencies/build.sbt +++ b/shaded-dependencies/build.sbt @@ -46,14 +46,27 @@ libraryDependencies ++= Seq( "org.apache.hive" % "hive-jdbc" % hive_version intransitive(), "org.apache.hive" % "hive-service" % hive_version intransitive(), "org.apache.hive" % "hive-serde" % hive_version intransitive(), - "org.apache.hive" % 
"hive-common" % hive_version intransitive() + "org.apache.hive" % "hive-common" % hive_version intransitive(), + + // To deal with hive3 metastore library 0.9.3 vs zeppelin thirft + // library version 0.9.1 conflict when runing Notebooks. + "org.apache.thrift" % "libfb303" % "0.9.3", + "org.apache.thrift" % "libthrift" % "0.9.3" ) assemblyShadeRules in assembly := Seq( ShadeRule.rename("org.apache.hadoop.hive.**" -> "com.qubole.shaded.hadoop.hive.@1").inAll, ShadeRule.rename("org.apache.hive.**" -> "com.qubole.shaded.hive.@1").inAll, ShadeRule.rename("org.apache.orc.**" -> "com.qubole.shaded.orc.@1").inAll, - ShadeRule.rename("com.google.**" -> "com.qubole.shaded.@1").inAll + ShadeRule.rename("org.apache.commons.**" -> "com.qubole.shaded.commons.@1").inAll, + ShadeRule.rename("org.apache.avro.**" -> "com.qubole.shaded.avro.@1").inAll, + ShadeRule.rename("org.apache.parquet.**" -> "com.qubole.shaded.parquet.@1").inAll, + ShadeRule.rename("org.apache.http.**" -> "com.qubole.shaded.http.@1").inAll, + ShadeRule.rename("org.apache.tez.**" -> "com.qubole.shaded.tez.@1").inAll, + + ShadeRule.rename("com.google.**" -> "com.qubole.shaded.@1").inAll, + ShadeRule.rename("com.facebook.fb303.**" -> "com.qubole.shaded.facebook.fb303.@1").inAll, + ShadeRule.rename("org.apache.thrift.**" -> "com.qubole.shaded.thrift.@1").inAll ) import sbtassembly.AssemblyPlugin.autoImport.ShadeRule diff --git a/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 4c91d95..6737a5c 100644 --- a/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -17,4 +17,4 @@ # limitations under the License. # -com.qubole.spark.datasources.hiveacid.HiveAcidDataSource +com.qubole.spark.hiveacid.datasource.HiveAcidDataSource diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidErrors.scala b/src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidErrors.scala deleted file mode 100644 index 4d3eb5c..0000000 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidErrors.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright 2019 Qubole, Inc. All rights reserved. - * - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.qubole.spark.datasources.hiveacid - -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -object HiveAcidErrors { - def tableNotSpecifiedException: Throwable = { - new IllegalArgumentException("'table' is not specified") - } - - def tableNotAcidException: Throwable = { - new IllegalArgumentException("The specified table is not an acid table") - } - - def validWriteIdsNotInitialized: Throwable = { - new RuntimeException("Valid WriteIds not initialized") - } - -} - -class AnalysisException ( - val message: String, - val line: Option[Int] = None, - val startPosition: Option[Int] = None, - // Some plans fail to serialize due to bugs in scala collections. - @transient val plan: Option[LogicalPlan] = None, - val cause: Option[Throwable] = None) - extends Exception(message, cause.orNull) with Serializable { - - def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = { - val newException = new AnalysisException(message, line, startPosition) - newException.setStackTrace(getStackTrace) - newException - } - - override def getMessage: String = { - val planAnnotation = Option(plan).flatten.map(p => s";\n$p").getOrElse("") - getSimpleMessage + planAnnotation - } - - // Outputs an exception without the logical plan. - // For testing only - def getSimpleMessage: String = { - val lineAnnotation = line.map(l => s" line $l").getOrElse("") - val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("") - s"$message;$lineAnnotation$positionAnnotation" - } -} diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidRelation.scala b/src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidRelation.scala deleted file mode 100644 index fa7bd66..0000000 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidRelation.scala +++ /dev/null @@ -1,296 +0,0 @@ -/* - * Copyright 2019 Qubole, Inc. All rights reserved. - * - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.qubole.spark.datasources.hiveacid - -import java.util.Locale -import java.util.concurrent.TimeUnit - -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.io.Output -import com.qubole.shaded.hadoop.hive.conf.HiveConf -import com.qubole.shaded.hadoop.hive.metastore.api.{FieldSchema, Table} -import com.qubole.shaded.hadoop.hive.ql.metadata -import com.qubole.shaded.hadoop.hive.ql.metadata.Hive -import com.qubole.shaded.hadoop.hive.ql.plan.TableDesc -import com.qubole.shaded.orc.mapreduce.OrcInputFormat -import com.qubole.spark.datasources.hiveacid.orc.OrcFilters -import com.qubole.spark.datasources.hiveacid.rdd.HiveTableReader -import com.qubole.shaded.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.commons.codec.binary.Base64 -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.serde2.ColumnProjectionUtils -import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.PrettyAttribute -import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} -import org.apache.spark.sql.sources.{BaseRelation, Filter, PrunedFilteredScan, _} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Row, SQLContext} - -import scala.collection.JavaConversions._ -import scala.collection.mutable - -class HiveAcidRelation(var sqlContext: SQLContext, - parameters: Map[String, String]) - extends BaseRelation - with PrunedFilteredScan - with Logging { - - private val tableName: String = parameters.getOrElse("table", { - throw HiveAcidErrors.tableNotSpecifiedException - }) - - private val hiveConf: HiveConf = HiveAcidRelation.createHiveConf(sqlContext.sparkContext) - - private val hTable: metadata.Table = { - // Currently we are creating and closing a connection to the hive metastore every time we need to do something. - // This can be optimized. 
- val hive: Hive = Hive.get(hiveConf) - val hTable = hive.getTable(tableName.split('.')(0), tableName.split('.')(1)) - Hive.closeCurrent() - hTable - } - - if (hTable.getParameters.get("transactional") != "true") { - throw HiveAcidErrors.tableNotAcidException - } - var isFullAcidTable: Boolean = hTable.getParameters.containsKey("transactional_properties") && - !hTable.getParameters.get("transactional_properties").equals("insert_only") - logInfo("Insert Only table: " + !isFullAcidTable) - - val dataSchema = StructType(hTable.getSd.getCols.toList.map(fromHiveColumn).toArray) - val partitionSchema = StructType(hTable.getPartitionKeys.toList.map(fromHiveColumn).toArray) - - override val schema: StructType = { - val overlappedPartCols = mutable.Map.empty[String, StructField] - partitionSchema.foreach { partitionField => - if (dataSchema.exists(getColName(_) == getColName(partitionField))) { - overlappedPartCols += getColName(partitionField) -> partitionField - } - } - StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ - partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) - } - - override def sizeInBytes: Long = { - val compressionFactor = sqlContext.sparkSession.sessionState.conf.fileCompressionFactor - (sqlContext.sparkSession.sessionState.conf.defaultSizeInBytes * compressionFactor).toLong - } - - override val needConversion: Boolean = false - - private def getColName(f: StructField): String = { - if (sqlContext.sparkSession.sessionState.conf.caseSensitiveAnalysis) { - f.name - } else { - f.name.toLowerCase(Locale.ROOT) - } - } - - private def fromHiveColumn(hc: FieldSchema): StructField = { - val columnType = getSparkSQLDataType(hc) - val metadata = if (hc.getType != columnType.catalogString) { - new MetadataBuilder().putString(HIVE_TYPE_STRING, hc.getType).build() - } else { - Metadata.empty - } - - val field = StructField( - name = hc.getName, - dataType = columnType, - nullable = true, - metadata = metadata) - Option(hc.getComment).map(field.withComment).getOrElse(field) - } - - /** Get the Spark SQL native DataType from Hive's FieldSchema. 
*/ - private def getSparkSQLDataType(hc: FieldSchema): DataType = { - try { - CatalystSqlParser.parseDataType(hc.getType) - } catch { - case e: ParseException => - throw new SparkException("Cannot recognize hive type string: " + hc.getType, e) - } - } - - override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { - val tableDesc = new TableDesc(hTable.getInputFormatClass, hTable.getOutputFormatClass, - hTable.getMetadata) - val partitionedColumnSet = hTable.getPartitionKeys.map(_.getName).toSet - val requiredNonPartitionedColumns = requiredColumns.filter(x => !partitionedColumnSet.contains(x)) - val requiredHiveFields = requiredColumns.map(x => hTable.getAllCols.find(_.getName == x).get) - val requiredAttributes = requiredHiveFields.map { x => - PrettyAttribute(x.getName, getSparkSQLDataType(x)) - } - val partitionAttributes = hTable.getPartitionKeys.map { x => - PrettyAttribute(x.getName, getSparkSQLDataType(x)) - } - - val hadoopConf = sqlContext.sparkSession.sessionState.newHadoopConf() - val (partitionFilters, otherFilters) = filters.partition { predicate => - !predicate.references.isEmpty && - predicate.references.toSet.subsetOf(hTable.getPartColNames.toSet) - } - val dataFilters = otherFilters.filter(_ - .references.intersect(hTable.getPartColNames).isEmpty - ) - logInfo(s"total filters : ${filters.size}: " + - s"dataFilters: ${dataFilters.size} " + - s"partitionFilters: ${partitionFilters.size}") - - setPushDownFiltersInHadoopConf(hadoopConf, dataFilters) - setRequiredColumnsInHadoopConf(hadoopConf, requiredNonPartitionedColumns) - - logDebug(s"sarg.pushdown: ${hadoopConf.get("sarg.pushdown")}," + - s"hive.io.file.readcolumn.names: ${hadoopConf.get("hive.io.file.readcolumn.names")}, " + - s"hive.io.file.readcolumn.ids: ${hadoopConf.get("hive.io.file.readcolumn.ids")}") - - val acidState = new HiveAcidState(sqlContext.sparkSession, hiveConf, hTable, - sqlContext.sparkSession.sessionState.conf.defaultSizeInBytes, partitionSchema, isFullAcidTable) - - val hiveReader = new HiveTableReader( - requiredAttributes, - partitionAttributes, - tableDesc, - sqlContext.sparkSession, - acidState, - hadoopConf) - if (hTable.isPartitioned) { - val requiredPartitions = getRawPartitions(partitionFilters) - hiveReader.makeRDDForPartitionedTable(requiredPartitions).asInstanceOf[RDD[Row]] - } else { - hiveReader.makeRDDForTable(hTable).asInstanceOf[RDD[Row]] - } - } - - private def setRequiredColumnsInHadoopConf(conf: Configuration, requiredColumns: Seq[String]): Unit = { - val dataCols: Seq[String] = hTable.getCols.map(_.getName) - val requiredColumnIndexes = requiredColumns.map(a => dataCols.indexOf(a): Integer) - val (sortedIDs, sortedNames) = requiredColumnIndexes.zip(requiredColumns).sorted.unzip - conf.set(ColumnProjectionUtils.READ_ALL_COLUMNS, "false") - conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, sortedNames.mkString(",")) - conf.set(ColumnProjectionUtils.READ_COLUMN_IDS_CONF_STR, sortedIDs.mkString(",")) - } - - private def setPushDownFiltersInHadoopConf(conf: Configuration, dataFilters: Array[Filter]): Unit = { - if (isPredicatePushdownEnabled()) { - OrcFilters.createFilter(dataSchema, dataFilters).foreach { f => - def toKryo(obj: com.qubole.shaded.hadoop.hive.ql.io.sarg.SearchArgument): String = { - val out = new Output(4 * 1024, 10 * 1024 * 1024) - new Kryo().writeObject(out, obj) - out.close() - return Base64.encodeBase64String(out.toBytes) - } - - logDebug(s"searchArgument: ${f}") - conf.set("sarg.pushdown", toKryo(f)) - 
conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) - } - } - } - - private def isPredicatePushdownEnabled(): Boolean = { - val sqlConf = sqlContext.sparkSession.sessionState.conf - sqlConf.getConfString("spark.sql.acidDs.enablePredicatePushdown", "true") == "true" - } - - - private def convertFilters(table: Table, filters: Seq[Filter]): String = { - def convertInToOr(name: String, values: Seq[Any]): String = { - values.map(value => s"$name = $value").mkString("(", " or ", ")") - } - - def convert(filter: Filter): Option[String] = filter match { - case In (name, values) => - Some(convertInToOr(name, values)) - - case EqualTo(name, value) => - Some(s"$name = $value") - - case GreaterThan(name, value) => - Some(s"$name > $value") - - case GreaterThanOrEqual(name, value) => - Some(s"$name >= $value") - - case LessThan(name, value) => - Some(s"$name < $value") - - case LessThanOrEqual(name, value) => - Some(s"$name <= $value") - - case And(filter1, filter2) => - val converted = convert(filter1) ++ convert(filter2) - if (converted.isEmpty) { - None - } else { - Some(converted.mkString("(", " and ", ")")) - } - - case Or(filter1, filter2) => - for { - left <- convert(filter1) - right <- convert(filter2) - } yield s"($left or $right)" - - case _ => None - } - - filters.flatMap(convert).mkString(" and ") - } - - - def getRawPartitions(partitionFilters: Array[Filter]): Seq[metadata.Partition] = { - val prunedPartitions = - if (sqlContext.sparkSession.sessionState.conf.metastorePartitionPruning && - partitionFilters.size > 0) { - val normalizedFilters = convertFilters(hTable.getTTable, partitionFilters) - val hive: Hive = Hive.get(hiveConf) - val hT = hive.getPartitionsByFilter(hTable, normalizedFilters) - Hive.closeCurrent() - hT - } else { - val hive: Hive = Hive.get(hiveConf) - val hT = hive.getPartitions(hTable) - Hive.closeCurrent() - hT - } - logDebug(s"partition count = ${prunedPartitions.size()}") - prunedPartitions.toSeq - } -} - - -object HiveAcidRelation extends Logging { - def createHiveConf(sparkContext: SparkContext): HiveConf = { - val hiveConf = new HiveConf() - (sparkContext.hadoopConfiguration.iterator().map(kv => kv.getKey -> kv.getValue) - ++ sparkContext.getConf.getAll.toMap).foreach { case (k, v) => - logDebug( - s""" - |Applying Hadoop/Hive/Spark and extra properties to Hive Conf: - |$k=${if (k.toLowerCase(Locale.ROOT).contains("password")) "xxx" else v} - """.stripMargin) - hiveConf.set(k, v) - } - hiveConf - } -} diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidState.scala b/src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidState.scala deleted file mode 100644 index 6169e00..0000000 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidState.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright 2019 Qubole, Inc. All rights reserved. - * - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.qubole.spark.datasources.hiveacid - -import com.qubole.shaded.hadoop.hive.common.{ValidTxnWriteIdList, ValidWriteIdList} -import com.qubole.shaded.hadoop.hive.conf.HiveConf -import com.qubole.shaded.hadoop.hive.metastore.HiveMetaStoreClient -import com.qubole.shaded.hadoop.hive.metastore.txn.TxnUtils -import com.qubole.shaded.hadoop.hive.ql.metadata -import org.apache.hadoop.fs.Path -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.types.StructType - -import scala.collection.JavaConversions._ - -class HiveAcidState(sparkSession: SparkSession, - val hiveConf: HiveConf, - val table: metadata.Table, - val sizeInBytes: Long, - val pSchema: StructType, - val isFullAcidTable: Boolean) extends Logging { - - val location: Path = table.getDataLocation - private val dbName: String = table.getDbName - private val tableName: String = table.getTableName - private val txnId: Long = -1 - private var validWriteIdsNoTxn: ValidWriteIdList = _ - - def beginRead(): Unit = { - // Get write ids to read. Currently, this data source does not open a transaction or take locks against - // it's read entities(partitions). This can be enhanced in the future - val client = new HiveMetaStoreClient(hiveConf, null, false) - val validTxns = client.getValidTxns() - val txnWriteIds: ValidTxnWriteIdList = TxnUtils.createValidTxnWriteIdList(txnId, - client.getValidWriteIds(Seq(dbName + "." + tableName), - validTxns.writeToString())) - validWriteIdsNoTxn = txnWriteIds.getTableValidWriteIdList(table.getDbName + "." + table.getTableName) - client.close() - } - - def end(): Unit = { - // no op for now. If we start taking locks in the future, this can be implemented to release the locks and - // close the transaction - } - - def getValidWriteIds: ValidWriteIdList = { - if (validWriteIdsNoTxn == null) { - throw HiveAcidErrors.validWriteIdsNotInitialized - } - validWriteIdsNoTxn - } -} diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/util/InputFileBlockHolder.scala b/src/main/scala/com/qubole/spark/datasources/hiveacid/util/InputFileBlockHolder.scala deleted file mode 100644 index 943b71e..0000000 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/util/InputFileBlockHolder.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright 2019 Qubole, Inc. All rights reserved. - * - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.qubole.spark.datasources.hiveacid.util - -import org.apache.spark.unsafe.types.UTF8String - -object InputFileBlockHolder { - /** - * A wrapper around some input file information. 
- * - * @param filePath path of the file read, or empty string if not available. - * @param startOffset starting offset, in bytes, or -1 if not available. - * @param length size of the block, in bytes, or -1 if not available. - */ - private class FileBlock(val filePath: UTF8String, val startOffset: Long, val length: Long) { - def this() { - this(UTF8String.fromString(""), -1, -1) - } - } - - /** - * The thread variable for the name of the current file being read. This is used by - * the InputFileName function in Spark SQL. - */ - private[this] val inputBlock: InheritableThreadLocal[FileBlock] = - new InheritableThreadLocal[FileBlock] { - override protected def initialValue(): FileBlock = new FileBlock - } - - /** - * Returns the holding file name or empty string if it is unknown. - */ - def getInputFilePath: UTF8String = inputBlock.get().filePath - - /** - * Returns the starting offset of the block currently being read, or -1 if it is unknown. - */ - def getStartOffset: Long = inputBlock.get().startOffset - - /** - * Returns the length of the block being read, or -1 if it is unknown. - */ - def getLength: Long = inputBlock.get().length - - /** - * Sets the thread-local input block. - */ - def set(filePath: String, startOffset: Long, length: Long): Unit = { - require(filePath != null, "filePath cannot be null") - require(startOffset >= 0, s"startOffset ($startOffset) cannot be negative") - require(length >= 0, s"length ($length) cannot be negative") - inputBlock.set(new FileBlock(UTF8String.fromString(filePath), startOffset, length)) - } - - /** - * Clears the input file block to default value. - */ - def unset(): Unit = inputBlock.remove() -} diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/util/NextIterator.scala b/src/main/scala/com/qubole/spark/datasources/hiveacid/util/NextIterator.scala deleted file mode 100644 index 55b9616..0000000 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/util/NextIterator.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright 2019 Qubole, Inc. All rights reserved. - * - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.qubole.spark.datasources.hiveacid.util - -abstract class NextIterator[U] extends Iterator[U] { - - private var gotNext = false - private var nextValue: U = _ - private var closed = false - protected var finished = false - - /** - * Method for subclasses to implement to provide the next element. - * - * If no next element is available, the subclass should set `finished` - * to `true` and may return any value (it will be ignored). - * - * This convention is required because `null` may be a valid value, - * and using `Option` seems like it might create unnecessary Some/None - * instances, given some iterators might be called in a tight loop. 
- * - * @return U, or set 'finished' when done - */ - protected def getNext(): U - - /** - * Method for subclasses to implement when all elements have been successfully - * iterated, and the iteration is done. - * - * Note: `NextIterator` cannot guarantee that `close` will be - * called because it has no control over what happens when an exception - * happens in the user code that is calling hasNext/next. - * - * Ideally you should have another try/catch, as in HadoopRDD, that - * ensures any resources are closed should iteration fail. - */ - protected def close() - - /** - * Calls the subclass-defined close method, but only once. - * - * Usually calling `close` multiple times should be fine, but historically - * there have been issues with some InputFormats throwing exceptions. - */ - def closeIfNeeded() { - if (!closed) { - // Note: it's important that we set closed = true before calling close(), since setting it - // afterwards would permit us to call close() multiple times if close() threw an exception. - closed = true - close() - } - } - - override def hasNext: Boolean = { - if (!finished) { - if (!gotNext) { - nextValue = getNext() - if (finished) { - closeIfNeeded() - } - gotNext = true - } - } - !finished - } - - override def next(): U = { - if (!hasNext) { - throw new NoSuchElementException("End of stream") - } - gotNext = false - nextValue - } -} diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/.gitignore b/src/main/scala/com/qubole/spark/hiveacid/.gitignore similarity index 100% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/.gitignore rename to src/main/scala/com/qubole/spark/hiveacid/.gitignore diff --git a/src/main/scala/com/qubole/spark/hiveacid/HiveAcidAutoConvert.scala b/src/main/scala/com/qubole/spark/hiveacid/HiveAcidAutoConvert.scala new file mode 100644 index 0000000..58a8cb5 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/HiveAcidAutoConvert.scala @@ -0,0 +1,73 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.qubole.spark.hiveacid + +import java.util.Locale + +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, Filter} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.datasources.LogicalRelation + +import com.qubole.spark.hiveacid.datasource.HiveAcidDataSource + + +/** + * Analyzer rule to convert a transactional HiveRelation + * into LogicalRelation backed by HiveAcidRelation + * @param spark - spark session + */ +case class HiveAcidAutoConvert(spark: SparkSession) extends Rule[LogicalPlan] { + + private def isConvertible(relation: HiveTableRelation): Boolean = { + val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + relation.tableMeta.properties.getOrElse("transactional", "false").toBoolean + } + + private def convert(relation: HiveTableRelation): LogicalRelation = { + val options = relation.tableMeta.properties ++ + relation.tableMeta.storage.properties ++ Map("table" -> relation.tableMeta.qualifiedName) + + val newRelation = new HiveAcidDataSource().createRelation(spark.sqlContext, options) + LogicalRelation(newRelation, isStreaming = false) + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + plan resolveOperators { + // Write path + case InsertIntoTable(r: HiveTableRelation, partition, query, overwrite, ifPartitionNotExists) + if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && isConvertible(r) => + InsertIntoTable(convert(r), partition, query, overwrite, ifPartitionNotExists) + + // Read path + case relation: HiveTableRelation + if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) => + convert(relation) + } + } +} + +class HiveAcidAutoConvertExtension extends (SparkSessionExtensions => Unit) { + def apply(e: SparkSessionExtensions): Unit = { + e.injectResolutionRule(HiveAcidAutoConvert.apply) + } +} diff --git a/src/main/scala/com/qubole/spark/hiveacid/HiveAcidErrors.scala b/src/main/scala/com/qubole/spark/hiveacid/HiveAcidErrors.scala new file mode 100644 index 0000000..9c47477 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/HiveAcidErrors.scala @@ -0,0 +1,114 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.qubole.spark.hiveacid + +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +object HiveAcidErrors { + def tableNotSpecifiedException(): Throwable = { + new IllegalArgumentException("'table' is not specified in parameters") + } + + def unsupportedFunction(): Throwable = { + new java.lang.UnsupportedOperationException() + } + + def invalidOperationType(operation: String): Throwable = { + new RuntimeException(s"Invalid operation type - $operation") + } + + def unsupportedSaveMode(saveMode: SaveMode): Throwable = { + new RuntimeException(s"Unsupported save mode - $saveMode") + } + def unsupportedOperationTypeInsertOnlyTable(operation: String): Throwable = { + new RuntimeException(s"Unsupported operation type - $operation for InsertOnly tables") + } + + def tableNotAcidException(tableName: String): Throwable = { + new IllegalArgumentException(s"table $tableName is not an acid table") + } + + def couldNotAcquireLockException(exception: Exception = null): Throwable = { + new RuntimeException(s"Could not acquire lock.", exception) + } + + def couldNotAcquireLockException(state: String): Throwable = { + new RuntimeException(s"Could not acquire lock. Lock State: $state") + } + + def txnAlreadyClosed(txnId: Long): Throwable = { + new RuntimeException(s"Transaction $txnId is already closed") + } + + def txnAlreadyOpen(txnId: Long): Throwable = { + new RuntimeException(s"Transaction already opened. Existing txnId: $txnId") + } + + def txnNotStarted(table: String): Throwable = { + new RuntimeException(s"Transaction on $table not started") + } + + def txnNoTransaction(): Throwable = { + new RuntimeException(s"No transaction found") + } + + def tableSnapshotNonExistent(snapshotId: Long): Throwable = { + new RuntimeException(s"Table snapshost $snapshotId does not exist") + } + + def tableWriteIdRequestedBeforeTxnStart(table: String): Throwable = { + new RuntimeException(s"Write id requested for table $table before txn was started") + } + + def repeatedTxnId(txnId: Long, activeTxns: Seq[Long]): Throwable = { + new RuntimeException(s"Repeated transaction id $txnId," + + s"active transactions are [${activeTxns.mkString(",")}]") + } +} + +class AnalysisException( + val message: String, + val line: Option[Int] = None, + val startPosition: Option[Int] = None, + // Some plans fail to serialize due to bugs in scala collections. + @transient val plan: Option[LogicalPlan] = None, + val cause: Option[Throwable] = None) + extends Exception(message, cause.orNull) with Serializable { + + def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = { + val newException = new AnalysisException(message, line, startPosition) + newException.setStackTrace(getStackTrace) + newException + } + + override def getMessage: String = { + val planAnnotation = Option(plan).flatten.map(p => s";\n$p").getOrElse("") + getSimpleMessage + planAnnotation + } + + // Outputs an exception without the logical plan. 
+ // For testing only + def getSimpleMessage: String = { + val lineAnnotation = line.map(l => s" line $l").getOrElse("") + val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("") + s"$message;$lineAnnotation$positionAnnotation" + } +} \ No newline at end of file diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidUtils.scala b/src/main/scala/com/qubole/spark/hiveacid/HiveAcidOperation.scala similarity index 81% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidUtils.scala rename to src/main/scala/com/qubole/spark/hiveacid/HiveAcidOperation.scala index 8161896..13c8e49 100644 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidUtils.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/HiveAcidOperation.scala @@ -17,8 +17,9 @@ * limitations under the License. */ -package com.qubole.spark.datasources.hiveacid +package com.qubole.spark.hiveacid -object HiveAcidUtils { - val NAME = "HiveAcid" +private[hiveacid] object HiveAcidOperation extends Enumeration { + type OperationType = Value + val READ, INSERT_INTO, INSERT_OVERWRITE, DELETE, UPDATE = Value } diff --git a/src/main/scala/com/qubole/spark/hiveacid/HiveAcidTable.scala b/src/main/scala/com/qubole/spark/hiveacid/HiveAcidTable.scala new file mode 100644 index 0000000..bc92ddf --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/HiveAcidTable.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.qubole.spark.hiveacid + +import com.qubole.spark.hiveacid.reader.TableReader +import com.qubole.spark.hiveacid.writer.TableWriter +import com.qubole.spark.hiveacid.hive.HiveAcidMetadata +import com.qubole.spark.hiveacid.datasource.HiveAcidDataSource +import com.qubole.spark.hiveacid.rdd.EmptyRDD +import com.qubole.spark.hiveacid.transaction._ + +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, _} +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.Column +import org.apache.spark.sql.SqlUtils +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.datasources.LogicalRelation + + +/** + * Represents a hive acid table and exposes API to perform operations on top of it + * @param sparkSession - spark session object + * @param hiveAcidMetadata - metadata object + * @param parameters - additional parameters + */ +class HiveAcidTable(sparkSession: SparkSession, + hiveAcidMetadata: HiveAcidMetadata, + parameters: Map[String, String] + ) extends Logging { + + private var isLocalTxn: Boolean = false + private var curTxn: HiveAcidTxn = _ + + // Start local transaction if not passed. + private def getOrCreateTxn(): Unit = { + curTxn = HiveAcidTxn.currentTxn() + curTxn match { + case null => + // create local txn + curTxn = HiveAcidTxn.createTransaction(sparkSession) + curTxn.begin() + isLocalTxn = true + case txn => + logDebug(s"Existing Transactions $txn") + } + } + + // End and reset transaction and snapshot + // if locally started + private def unsetOrEndTxn(abort: Boolean = false): Unit = { + if (! isLocalTxn) { + return + } + curTxn.end(abort) + curTxn = null + isLocalTxn = false + } + + // Start and end transaction under protection. + private def inTxn(f: => Unit): Unit = synchronized { + getOrCreateTxn() + var abort = false + try { f } + catch { + case e: Exception => + logError("Unable to execute in transactions due to: " + e.getMessage) + abort = true; + } + finally { + unsetOrEndTxn(abort) + } + } + + /** + * Create dataframe to read based on hiveAcidTable and passed in filter. + * @return Dataframe + */ + private def readDF: DataFrame = { + // Fetch row with rowID in it + sparkSession.read.format(HiveAcidDataSource.NAME) + .options(parameters ++ + Map("includeRowIds" -> "true", "table" -> hiveAcidMetadata.fullyQualifiedName)) + .load() + } + + /** + * Return df after after applying update clause and filter clause. This df is used to + * update the table. 
+ * @param condition - condition string to identify rows which need to be updated
+ * @param newValues - Map of (column, value) to set
+ */
+ private def updateDF(condition: String, newValues: Map[String, String]): DataFrame = {
+
+ val df = readDF
+
+ val plan = df.queryExecution.analyzed
+ val qualifiedPlan = plan match {
+ case p: LogicalRelation =>
+ p.copy(output = p.output
+ .map((x: AttributeReference) =>
+ x.withQualifier(hiveAcidMetadata.fullyQualifiedName.split('.').toSeq))
+ )
+ case _ => plan
+ }
+ val resolvedExpr = SqlUtils.resolveReferences(sparkSession,
+ functions.expr(condition).expr,
+ qualifiedPlan)
+
+ val newDf = SqlUtils.convertToDF(sparkSession, qualifiedPlan)
+
+ def toStrColumnMap(map: Map[String, String]): Map[String, Column] = {
+ map.toSeq.map { case (k, v) =>
+ k -> functions.expr(SqlUtils.resolveReferences(sparkSession, functions.expr(v).expr,
+ qualifiedPlan).sql)}.toMap
+ }
+
+ val strColumnMap = toStrColumnMap(newValues)
+ val updateExpressions: Seq[Expression] =
+ newDf.queryExecution.optimizedPlan.output.map {
+ attr =>
+ if (strColumnMap.contains(attr.name)) {
+ strColumnMap(attr.name).expr
+ } else {
+ attr
+ }
+ }
+
+ val newColumns = updateExpressions.zip(df.queryExecution.optimizedPlan.output).map {
+ case (newExpr, origAttr) =>
+ new Column(Alias(newExpr, origAttr.name)())
+ }
+
+ newDf.filter(resolvedExpr.sql).select(newColumns: _*)
+ }
+
+ /**
+ * Return an RDD on top of the Hive ACID table
+ * @param requiredColumns - columns needed
+ * @param filters - filters that can be pushed down to file format
+ * @param readConf - read conf
+ * @return
+ */
+ def getRdd(requiredColumns: Array[String],
+ filters: Array[Filter],
+ readConf: ReadConf): RDD[Row] = {
+ var res: RDD[Row] = new EmptyRDD[Row](sparkSession.sparkContext)
+
+ // TODO: getRdd does not perform the read itself but returns an RDD, which materializes
+ // outside this function. For transactional guarantees, the transaction
+ // boundary needs to span the getRdd call. Currently we return the RDD
+ // without any protection.
+ inTxn {
+ val tableReader = new TableReader(sparkSession, curTxn, hiveAcidMetadata)
+ res = tableReader.getRdd(requiredColumns, filters, readConf)
+ }
+ res
+ }
+
+ /**
+ * Appends a given dataframe df into the hive acid table
+ * @param df - dataframe to insert
+ */
+ def insertInto(df: DataFrame): Unit = inTxn {
+ val tableWriter = new TableWriter(sparkSession, curTxn, hiveAcidMetadata)
+ tableWriter.process(HiveAcidOperation.INSERT_INTO, df)
+ }
+
+ /**
+ * Overwrites a given dataframe df onto the hive acid table
+ * @param df - dataframe to insert
+ */
+ def insertOverwrite(df: DataFrame): Unit = inTxn {
+ val tableWriter = new TableWriter(sparkSession, curTxn, hiveAcidMetadata)
+ tableWriter.process(HiveAcidOperation.INSERT_OVERWRITE, df)
+ }
+
+ /**
+ * Delete rows from the table based on a conditional expression.
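+ *
+ * A hypothetical end-to-end sketch, assuming a SparkSession `spark` configured against
+ * the Hive metastore and an existing ACID table (table and column names are examples only):
+ * {{{
+ *   val acidTable = HiveAcidTable.fromSparkSession(spark, "mydb.acid_events")
+ *   acidTable.delete("status = false")
+ *   acidTable.update("rank < 0", Map("rank" -> "0.0"))
+ * }}}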
+ * @param condition - Conditional filter for delete + */ + def delete(condition: String): Unit = inTxn { + val df = readDF + val resolvedExpr= SqlUtils.resolveReferences(sparkSession, + functions.expr(condition).expr, + df.queryExecution.analyzed) + val tableWriter = new TableWriter(sparkSession, curTxn, hiveAcidMetadata) + tableWriter.process(HiveAcidOperation.DELETE, df.filter(resolvedExpr.sql)) + } + + /** + * Update rows in the hive acid table based on condition and newValues + * @param condition - condition string to identify rows which needs to be updated + * @param newValues - Map of (column, value) to set + */ + def update(condition: String, newValues: Map[String, String]): Unit = inTxn { + val updateDf = updateDF(condition, newValues) + val tableWriter = new TableWriter(sparkSession, curTxn, hiveAcidMetadata) + tableWriter.process(HiveAcidOperation.UPDATE, updateDf) + } +} + +object HiveAcidTable { + def fromSparkSession(sparkSession: SparkSession, + fullyQualifiedTableName: String, + parameters: Map[String, String] = Map() + ): HiveAcidTable = { + + val hiveAcidMetadata: HiveAcidMetadata = + HiveAcidMetadata.fromSparkSession(sparkSession, fullyQualifiedTableName) + new HiveAcidTable(sparkSession, hiveAcidMetadata, parameters) + } +} diff --git a/src/main/scala/com/qubole/spark/hiveacid/ReadConf.scala b/src/main/scala/com/qubole/spark/hiveacid/ReadConf.scala new file mode 100644 index 0000000..c3582c1 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/ReadConf.scala @@ -0,0 +1,47 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.qubole.spark.hiveacid + +import org.apache.spark.sql.SparkSession + +/** + * Spark specific configuration container to be used by Hive Acid module + */ +case class ReadConf(predicatePushdownEnabled: Boolean = true, + metastorePartitionPruningEnabled: Boolean = true, + var includeRowIds: Boolean = false) + +object ReadConf { + + val PREDICATE_PUSHDOWN_CONF = "spark.sql.hiveAcid.enablePredicatePushdown" + + def build(sparkSession: SparkSession, parameters: Map[String, String]): ReadConf = { + val isPredicatePushdownEnabled: Boolean = { + val sqlConf = sparkSession.sessionState.conf + sqlConf.getConfString(PREDICATE_PUSHDOWN_CONF, "true") == "true" + } + new ReadConf( + isPredicatePushdownEnabled, + sparkSession.sessionState.conf.metastorePartitionPruning, + parameters.getOrElse("includeRowIds", "false").toBoolean + ) + } +} + diff --git a/src/main/scala/com/qubole/spark/hiveacid/datasource/HiveAcidDataSource.scala b/src/main/scala/com/qubole/spark/hiveacid/datasource/HiveAcidDataSource.scala new file mode 100644 index 0000000..47795a6 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/datasource/HiveAcidDataSource.scala @@ -0,0 +1,80 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.qubole.spark.hiveacid.datasource + +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.sources._ + +import com.qubole.spark.hiveacid.{HiveAcidTable, HiveAcidErrors} + +/** + * HiveAcid Data source implementation. + */ +class HiveAcidDataSource + extends RelationProvider // USING HiveAcid + with CreatableRelationProvider // Insert into/overwrite + with DataSourceRegister // FORMAT("HiveAcid") + with Logging { + + // returns relation for passed in table name + override def createRelation(sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + HiveAcidRelation(sqlContext.sparkSession, getFullyQualifiedTableName(parameters), parameters) + } + + // returns relation after writing passed in data frame. 
Table name is part of the parameters.
+ override def createRelation(sqlContext: SQLContext,
+ mode: SaveMode,
+ parameters: Map[String, String],
+ df: DataFrame): BaseRelation = {
+
+ val hiveAcidTable: HiveAcidTable = HiveAcidTable.fromSparkSession(
+ sqlContext.sparkSession,
+ getFullyQualifiedTableName(parameters),
+ parameters)
+
+ mode match {
+ case SaveMode.Overwrite =>
+ hiveAcidTable.insertOverwrite(df)
+ case SaveMode.Append =>
+ hiveAcidTable.insertInto(df)
+ // TODO: Add support for these
+ case SaveMode.ErrorIfExists | SaveMode.Ignore =>
+ HiveAcidErrors.unsupportedSaveMode(mode)
+ }
+ createRelation(sqlContext, parameters)
+ }
+
+ override def shortName(): String = {
+ HiveAcidDataSource.NAME
+ }
+
+ private def getFullyQualifiedTableName(parameters: Map[String, String]): String = {
+ parameters.getOrElse("table", {
+ throw HiveAcidErrors.tableNotSpecifiedException()
+ })
+ }
+}
+
+object HiveAcidDataSource {
+ val NAME = "HiveAcid"
+}
+
diff --git a/src/main/scala/com/qubole/spark/hiveacid/datasource/HiveAcidRelation.scala b/src/main/scala/com/qubole/spark/hiveacid/datasource/HiveAcidRelation.scala
new file mode 100644
index 0000000..4dde0ef
--- /dev/null
+++ b/src/main/scala/com/qubole/spark/hiveacid/datasource/HiveAcidRelation.scala
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2019 Qubole, Inc. All rights reserved.
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.qubole.spark.hiveacid.datasource
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext}
+import org.apache.spark.sql.sources.{BaseRelation, Filter, InsertableRelation, PrunedFilteredScan}
+import org.apache.spark.sql.types._
+
+import com.qubole.spark.hiveacid.{HiveAcidTable, ReadConf}
+import com.qubole.spark.hiveacid.hive.HiveAcidMetadata
+
+/**
+ * Container for all metadata, configuration and schema needed to perform operations on
+ * the Hive ACID datasource. This provides the plumbing; most of the heavy lifting is
+ * performed inside HiveAcidTable.
+ *
+ * @param sparkSession Spark Session object
+ * @param fullyQualifiedTableName Table name for the data source.
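+ * For example, with hypothetical names, the relation is constructed as
+ * {{{
+ *   HiveAcidRelation(spark, "mydb.acid_events", Map("includeRowIds" -> "true"))
+ * }}}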
+ * @param parameters user provided parameters required for reading and writing, + * including configuration + */ +case class HiveAcidRelation(sparkSession: SparkSession, + fullyQualifiedTableName: String, + parameters: Map[String, String]) + extends BaseRelation + with InsertableRelation + with PrunedFilteredScan + with Logging { + + private val hiveAcidMetadata: HiveAcidMetadata = HiveAcidMetadata.fromSparkSession( + sparkSession, + fullyQualifiedTableName + ) + private val hiveAcidTable: HiveAcidTable = new HiveAcidTable(sparkSession, + hiveAcidMetadata, parameters) + + private val readOptions = ReadConf.build(sparkSession, parameters) + + override def sqlContext: SQLContext = sparkSession.sqlContext + + override val schema: StructType = if (readOptions.includeRowIds) { + hiveAcidMetadata.tableSchemaWithRowId + } else { + hiveAcidMetadata.tableSchema + } + + override def insert(data: DataFrame, overwrite: Boolean): Unit = { + // sql insert into and overwrite + if (overwrite) { + hiveAcidTable.insertOverwrite(data) + } else { + hiveAcidTable.insertInto(data) + } + } + + override def sizeInBytes: Long = { + val compressionFactor = sparkSession.sessionState.conf.fileCompressionFactor + (sparkSession.sessionState.conf.defaultSizeInBytes * compressionFactor).toLong + } + + // FIXME: should it be true / false. Recommendation seems to + // be to leave it as true + override val needConversion: Boolean = false + + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { + val readOptions = ReadConf.build(sparkSession, parameters) + // sql "select *" + hiveAcidTable.getRdd(requiredColumns, filters, readOptions) + } +} diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/orc/.gitignore b/src/main/scala/com/qubole/spark/hiveacid/hive/.gitignore similarity index 100% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/orc/.gitignore rename to src/main/scala/com/qubole/spark/hiveacid/hive/.gitignore diff --git a/src/main/scala/com/qubole/spark/hiveacid/hive/HiveAcidMetadata.scala b/src/main/scala/com/qubole/spark/hiveacid/hive/HiveAcidMetadata.scala new file mode 100644 index 0000000..eb17fbb --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/hive/HiveAcidMetadata.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.qubole.spark.hiveacid.hive + +import java.util.Locale + +import scala.collection.JavaConversions._ +import scala.collection.mutable + +import com.qubole.shaded.hadoop.hive.conf.HiveConf +import com.qubole.shaded.hadoop.hive.ql.io.RecordIdentifier +import com.qubole.shaded.hadoop.hive.ql.metadata +import com.qubole.shaded.hadoop.hive.ql.metadata.Hive +import com.qubole.shaded.hadoop.hive.ql.plan.TableDesc +import com.qubole.spark.hiveacid.util.Util +import com.qubole.spark.hiveacid.HiveAcidErrors +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred.{InputFormat, OutputFormat} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.types._ + +/** + * Represents metadata for hive acid table and exposes API to perform operations on top of it + * @param sparkSession - spark session object + * @param fullyQualifiedTableName - the fully qualified hive acid table name + */ +class HiveAcidMetadata(sparkSession: SparkSession, + fullyQualifiedTableName: String) extends Logging { + + // hive conf + private val hiveConf: HiveConf = HiveConverter.getHiveConf(sparkSession.sparkContext) + + // a hive representation of the table + val hTable: metadata.Table = { + val hive: Hive = Hive.get(hiveConf) + val table = sparkSession.sessionState.sqlParser.parseTableIdentifier(fullyQualifiedTableName) + val hTable = hive.getTable( + table.database match { + case Some(database) => database + case None => HiveAcidMetadata.DEFAULT_DATABASE + }, table.identifier) + Hive.closeCurrent() + hTable + } + + if (hTable.getParameters.get("transactional") != "true") { + throw HiveAcidErrors.tableNotAcidException(hTable.getFullyQualifiedName) + } + + val isFullAcidTable: Boolean = hTable.getParameters.containsKey("transactional_properties") && + !hTable.getParameters.get("transactional_properties").equals("insert_only") + val isInsertOnlyTable: Boolean = !isFullAcidTable + + // Table properties + val isPartitioned: Boolean = hTable.isPartitioned + val rootPath: Path = hTable.getDataLocation + val dbName: String = hTable.getDbName + val tableName: String = hTable.getTableName + val fullyQualifiedName: String = hTable.getFullyQualifiedName + + // Schema properties + val dataSchema = StructType(hTable.getSd.getCols.toList.map( + HiveConverter.getCatalystStructField).toArray) + + val partitionSchema = StructType(hTable.getPartitionKeys.toList.map( + HiveConverter.getCatalystStructField).toArray) + + val rowIdSchema: StructType = { + StructType( + RecordIdentifier.Field.values().map { + field => + StructField( + name = field.name(), + dataType = HiveConverter.getCatalystType(field.fieldType.getTypeName), + nullable = true) + } + ) + } + + val tableSchema: StructType = { + val overlappedPartCols = mutable.Map.empty[String, StructField] + partitionSchema.foreach { partitionField => + if (dataSchema.exists(getColName(_) == getColName(partitionField))) { + overlappedPartCols += getColName(partitionField) -> partitionField + } + } + StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ + partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) + } + + val tableSchemaWithRowId: StructType = { + StructType( + Seq( + StructField("rowId", rowIdSchema) + ) ++ tableSchema.fields) + } + + lazy val tableDesc: TableDesc = { + val inputFormatClass: Class[InputFormat[Writable, Writable]] = + Util.classForName(hTable.getInputFormatClass.getName, + loadShaded = 
true).asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] + val outputFormatClass: Class[OutputFormat[Writable, Writable]] = + Util.classForName(hTable.getOutputFormatClass.getName, + loadShaded = true).asInstanceOf[java.lang.Class[OutputFormat[Writable, Writable]]] + new TableDesc( + inputFormatClass, + outputFormatClass, + hTable.getMetadata) + } + + /** + * Returns list of partitions satisfying partition predicates + * @param partitionFilters - filters to apply + */ + def getRawPartitions(partitionFilters: Option[String] = None): Seq[metadata.Partition] = { + val hive: Hive = Hive.get(hiveConf) + val prunedPartitions = partitionFilters match { + case Some(filter) => hive.getPartitionsByFilter(hTable, filter) + case None => hive.getPartitions(hTable) + } + Hive.closeCurrent() + + logDebug(s"partition count = ${prunedPartitions.size()}") + prunedPartitions.toSeq + } + + private def getColName(field: StructField): String = { + if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { + field.name + } else { + field.name.toLowerCase(Locale.ROOT) + } + } +} + +object HiveAcidMetadata { + val DEFAULT_DATABASE = "default" + + def fromSparkSession(sparkSession: SparkSession, + fullyQualifiedTableName: String): HiveAcidMetadata = { + new HiveAcidMetadata( + sparkSession, + fullyQualifiedTableName) + } +} diff --git a/src/main/scala/com/qubole/spark/hiveacid/hive/HiveConverter.scala b/src/main/scala/com/qubole/spark/hiveacid/hive/HiveConverter.scala new file mode 100644 index 0000000..3918d45 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/hive/HiveConverter.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.qubole.spark.hiveacid.hive + +import java.util.Locale + +import com.qubole.shaded.hadoop.hive.conf.HiveConf +import com.qubole.shaded.hadoop.hive.metastore.api.FieldSchema +import org.apache.commons.lang3.StringUtils +import org.apache.spark.internal.Logging +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ + +import scala.collection.JavaConversions._ + +/** + * Encapsulates everything (extensions, workarounds, quirks) to handle the + * SQL dialect conversion between catalyst and hive. 
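+ *
+ * For illustration, compileFilter translates Spark datasource Filter objects into
+ * equivalent SQL predicate strings (attribute names and values below are hypothetical):
+ * {{{
+ *   compileFilter(EqualTo("dt", "2019-01-01"))  // Some("dt = '2019-01-01'")
+ *   compileFilter(GreaterThan("rank", 10))      // Some("rank > 10")
+ * }}}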
+ */ +private[hiveacid] object HiveConverter extends Logging { + + def getCatalystStructField(hc: FieldSchema): StructField = { + val columnType = getCatalystType(hc.getType) + val metadata = if (hc.getType != columnType.catalogString) { + new MetadataBuilder().putString(HIVE_TYPE_STRING, hc.getType).build() + } else { + Metadata.empty + } + + val field = StructField( + name = hc.getName, + dataType = columnType, + nullable = true, + metadata = metadata) + Option(hc.getComment).map(field.withComment).getOrElse(field) + } + + def getCatalystType(dataType: String): DataType = { + try { + CatalystSqlParser.parseDataType(dataType) + } catch { + case e: ParseException => + throw new SparkException("Cannot recognize hive type string: " + dataType, e) + } + } + + def getHiveConf(sparkContext: SparkContext): HiveConf = { + val hiveConf = new HiveConf() + (sparkContext.hadoopConfiguration.iterator().map(kv => kv.getKey -> kv.getValue) + ++ sparkContext.getConf.getAll.toMap).foreach { case (k, v) => + logDebug( + s""" + |Applying Hadoop/Hive/Spark and extra properties to Hive Conf: + |$k=${if (k.toLowerCase(Locale.ROOT).contains("password")) "xxx" else v} + """.stripMargin) + hiveConf.set(k, v) + } + hiveConf + } + + /** + * Escape special characters in SQL string literals. + * + * @param value The string to be escaped. + * @return Escaped string. + */ + private def escapeSql(value: String): String = { + // TODO: how to handle null + StringUtils.replace(value, "'", "''") + } + + /** + * Converts value to SQL expression. + * @param value The value to be converted. + * @return Converted value. + */ + private def compileValue(value: Any): Any = value match { + case stringValue: String => s"'${escapeSql(stringValue)}'" + case _ => value + } + + /** + * Turns a single Filter into a String representing a SQL expression. + * Returns None for an unhandled filter. + */ + def compileFilter(f: Filter): Option[String] = Option(x = f match { + case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case EqualNullSafe(attr, value) => + val col = attr + s"(NOT ($col != ${compileValue(value)} OR $col = 'NULL' OR " + + s"${compileValue(value)} = 'NULL') OR " + + s"($col = 'NULL' AND ${compileValue(value)} = 'NULL'))" + case LessThan(attr, value) => s"$attr < ${compileValue(value)}" + case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" + case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" + case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" + case IsNull(attr) => s"$attr = 'NULL'" + case IsNotNull(attr) => s"$attr != 'NULL'" + case StringStartsWith(attr, value) => s"$attr LIKE '$value%'" + case StringEndsWith(attr, value) => s"$attr LIKE '%$value'" + case StringContains(attr, value) => s"$attr LIKE '%$value%'" + case In(attr, value) => s"$attr IN (${compileValue(value)})" + case Not(`f`) => compileFilter(f).map(p => s"(NOT ($p))").orNull + case Or(f1, f2) => + // We can't compile Or filter unless both sub-filters are compiled successfully. + // It applies too for the following And filter. + // If we can make sure compileFilter supports all filters, we can remove this check. 
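+ // For example (hypothetical attribute): Or(EqualTo("dt", "2019"), IsNull("dt"))
+ // compiles to "(dt = '2019') OR (dt = 'NULL')", while an Or with an unsupported
+ // sub-filter compiles to null and is later dropped by flatMap in compileFilters.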
+ val or = Seq(f1, f2) flatMap compileFilter + if (or.size == 2) { + or.map(p => s"($p)").mkString(" OR ") + } else null + case And(f1, f2) => + val and = Seq(f1, f2).flatMap(compileFilter) + if (and.size == 2) { + and.map(p => s"($p)").mkString(" AND ") + } else null + case _ => null + }) + + + def compileFilters(filters: Seq[Filter]): String = { + val str = filters.flatMap(compileFilter).mkString(" and ") + logDebug(str) + str + } +} diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/util/EmptyRDD.scala b/src/main/scala/com/qubole/spark/hiveacid/rdd/EmptyRDD.scala similarity index 90% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/util/EmptyRDD.scala rename to src/main/scala/com/qubole/spark/hiveacid/rdd/EmptyRDD.scala index 0894aae..625fb7c 100644 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/util/EmptyRDD.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/rdd/EmptyRDD.scala @@ -17,14 +17,14 @@ * limitations under the License. */ -package com.qubole.spark.datasources.hiveacid.util - -import org.apache.spark.rdd.RDD -import org.apache.spark.{Partition, SparkContext, TaskContext} +package com.qubole.spark.hiveacid.rdd import scala.reflect.ClassTag -class EmptyRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) { +import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.rdd.RDD + +private[hiveacid] class EmptyRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) { override def getPartitions: Array[Partition] = Array.empty diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/rdd/Hive3Rdd.scala b/src/main/scala/com/qubole/spark/hiveacid/rdd/HiveAcidRDD.scala similarity index 58% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/rdd/Hive3Rdd.scala rename to src/main/scala/com/qubole/spark/hiveacid/rdd/HiveAcidRDD.scala index c2155c7..35008dd 100644 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/rdd/Hive3Rdd.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/rdd/HiveAcidRDD.scala @@ -17,44 +17,41 @@ * limitations under the License. 
*/ -package com.qubole.spark.datasources.hiveacid.rdd +package com.qubole.spark.hiveacid.rdd import java.io.{FileNotFoundException, IOException} import java.text.SimpleDateFormat import java.util.concurrent.ConcurrentHashMap import java.util.{Date, Locale} +import scala.collection.JavaConversions._ +import scala.collection.mutable.ListBuffer +import scala.reflect.ClassTag import com.qubole.shaded.hadoop.hive.common.ValidWriteIdList -import com.qubole.shaded.hadoop.hive.ql.io.{AcidUtils, HiveInputFormat} -import com.qubole.spark.datasources.hiveacid.HiveAcidState -import com.qubole.spark.datasources.hiveacid.util.{InputFileBlockHolder, NextIterator, SerializableConfiguration, Util} -import com.qubole.spark.datasources.hiveacid.rdd.Hive3RDD.Hive3PartitionsWithSplitRDD -import com.qubole.spark.datasources.hiveacid.util.{SerializableWritable => _, _} +import com.qubole.shaded.hadoop.hive.ql.io.{AcidInputFormat, AcidUtils, HiveInputFormat, RecordIdentifier} +import com.qubole.spark.hiveacid.rdd.HiveAcidRDD.HiveAcidPartitionsWithSplitRDD +import com.qubole.spark.hiveacid.util.{SerializableConfiguration, Util} +import com.qubole.spark.hiveacid.util.{SerializableWritable => _} import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapred.lib.CombineFileSplit import org.apache.hadoop.mapred.{FileInputFormat, _} import org.apache.hadoop.mapreduce.TaskType import org.apache.hadoop.util.ReflectionUtils -import org.apache.spark._ +import org.apache.spark.{Partitioner, _} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.InputMetrics import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import scala.collection.JavaConversions._ -import scala.collection.mutable.ListBuffer -import scala.reflect.ClassTag +// This file has lot of borrowed code from org.apache.spark.rdd.HadoopRdd -object Cache { - import com.google.common.collect.MapMaker +private object Cache { val jobConf = new ConcurrentHashMap[String, Any]() } -class Hive3Partition(rddId: Int, override val index: Int, s: InputSplit) +private class HiveAcidPartition(rddId: Int, override val index: Int, s: InputSplit) extends Partition { val inputSplit = new SerializableWritable[InputSplit](s) @@ -65,39 +62,40 @@ class Hive3Partition(rddId: Int, override val index: Int, s: InputSplit) } /** - * :: DeveloperApi :: - * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, - * sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`). - * - * @param sc The SparkContext to associate the RDD with. - * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed - * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. - * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration. - * @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that Hive3RDD - * creates. - * @param inputFormatClass Storage format of the data to be read. - * @param keyClass Class of the key associated with the inputFormatClass. - * @param valueClass Class of the value associated with the inputFormatClass. - * @param minPartitions Minimum number of Hive3RDD partitions (Hadoop Splits) to generate. 
- * - * @note Instantiating this class directly is not recommended, please use - * `org.apache.spark.SparkContext.Hive3RDD()` - */ -@DeveloperApi -class Hive3RDD[K, V]( - sc: SparkContext, - @transient val acidState: HiveAcidState, - broadcastedConf: Broadcast[SerializableConfiguration], - initLocalJobConfFuncOpt: Option[JobConf => Unit], - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], - valueClass: Class[V], - minPartitions: Int) - extends RDD[(K, V)](sc, Nil) with Logging { + * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, + * sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`). + * + * @param sc The SparkContext to associate the RDD with. + * @param validWriteIds The list of valid write ids. + * @param isFullAcidTable if table is full acid table. + * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed + * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. + * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration. + * @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HiveAcidRDD + * creates. + * @param inputFormatClass Storage format of the data to be read. + * @param keyClass Class of the key associated with the inputFormatClass. + * @param valueClass Class of the value associated with the inputFormatClass. + * @param minPartitions Minimum number of HiveAcidRDD partitions (Hadoop Splits) to generate. + * + * @note Instantiating this class directly is not recommended, please use + * `org.apache.spark.SparkContext.HiveAcidRDD()` + */ +private[hiveacid] class HiveAcidRDD[K, V](sc: SparkContext, + @transient val validWriteIds: ValidWriteIdList, + @transient val isFullAcidTable: Boolean, + broadcastedConf: Broadcast[SerializableConfiguration], + initLocalJobConfFuncOpt: Option[JobConf => Unit], + inputFormatClass: Class[_ <: InputFormat[K, V]], + keyClass: Class[K], + valueClass: Class[V], + minPartitions: Int) + extends RDD[(RecordIdentifier, V)](sc, Nil) with Logging { def this( sc: SparkContext, - @transient acidState: HiveAcidState, + @transient validWriteIds: ValidWriteIdList, + @transient isFullAcidTable: Boolean, conf: JobConf, inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], @@ -105,7 +103,8 @@ class Hive3RDD[K, V]( minPartitions: Int) = { this( sc, - acidState, + validWriteIds, + isFullAcidTable, sc.broadcast(new SerializableConfiguration(conf)) .asInstanceOf[Broadcast[SerializableConfiguration]], initLocalJobConfFuncOpt = None, @@ -123,19 +122,19 @@ class Hive3RDD[K, V]( private val createTime = new Date() private val shouldCloneJobConf = - sparkContext.getConf.getBoolean("spark.hadoop.cloneConf", false) + sparkContext.getConf.getBoolean("spark.hadoop.cloneConf", defaultValue = false) private val ignoreCorruptFiles = - sparkContext.getConf.getBoolean("spark.files.ignoreCorruptFiles", false) + sparkContext.getConf.getBoolean("spark.files.ignoreCorruptFiles", defaultValue = false) private val ignoreMissingFiles = - sparkContext.getConf.getBoolean("spark.files.ignoreMissingFiles", false) + sparkContext.getConf.getBoolean("spark.files.ignoreMissingFiles", defaultValue = false) private val ignoreEmptySplits = - sparkContext.getConf.getBoolean("spark.hadoopRDD.ignoreEmptySplits", false) + sparkContext.getConf.getBoolean("spark.hadoopRDD.ignoreEmptySplits", defaultValue = false) // Returns a JobConf that will be 
used on slaves to obtain input splits for Hadoop reads. - protected def getJobConf(): JobConf = { + protected def getJobConf: JobConf = { val conf: Configuration = broadcastedConf.value.value if (shouldCloneJobConf) { // Hadoop Configuration objects are not thread-safe, which may lead to various problems if @@ -145,7 +144,7 @@ class Hive3RDD[K, V]( // clone can be very expensive. To avoid unexpected performance regressions for workloads and // Hadoop versions that do not suffer from these thread-safety issues, this cloning is // disabled by default. - Hive3RDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + HiveAcidRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { logDebug("Cloning Hadoop Configuration") val newJobConf = new JobConf(conf) if (!conf.isInstanceOf[JobConf]) { @@ -154,29 +153,30 @@ class Hive3RDD[K, V]( newJobConf } } else { - if (conf.isInstanceOf[JobConf]) { - logDebug("Re-using user-broadcasted JobConf") - conf.asInstanceOf[JobConf] - } else { - Option(Hive3RDD.getCachedMetadata(jobConfCacheKey)) - .map { conf => - logDebug("Re-using cached JobConf") - conf.asInstanceOf[JobConf] - } - .getOrElse { - // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in - // the local process. The local cache is accessed through Hive3RDD.putCachedMetadata(). - // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary - // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097, - // HADOOP-10456). - Hive3RDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { - logDebug("Creating new JobConf and caching it for later re-use") - val newJobConf = new JobConf(conf) - initLocalJobConfFuncOpt.foreach(f => f(newJobConf)) - Hive3RDD.putCachedMetadata(jobConfCacheKey, newJobConf) - newJobConf + conf match { + case c: JobConf => + logDebug("Re-using user-broadcasted JobConf") + c + case _ => + Option(HiveAcidRDD.getCachedMetadata(jobConfCacheKey)) + .map { conf => + logDebug("Re-using cached JobConf") + conf.asInstanceOf[JobConf] + } + .getOrElse { + // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in + // the local process. The local cache is accessed through HiveAcidRDD.putCachedMetadata(). + // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary + // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097, + // HADOOP-10456). 
+ HiveAcidRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Creating new JobConf and caching it for later re-use") + val newJobConf = new JobConf(conf) + initLocalJobConfFuncOpt.foreach(f => f(newJobConf)) + HiveAcidRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + newJobConf + } } - } } } } @@ -192,12 +192,11 @@ class Hive3RDD[K, V]( } override def getPartitions: Array[Partition] = { - val validWriteIds: ValidWriteIdList = acidState.getValidWriteIds - //val ValidWriteIdList = acidState.getValidWriteIdsNoTxn - var jobConf = getJobConf() + var jobConf = getJobConf - if (acidState.isFullAcidTable) { - // If full ACID table, just set the right writeIds, the OrcInputFormat.getSplits() will take care of the rest + if (isFullAcidTable) { + // If full ACID table, just set the right writeIds, the + // OrcInputFormat.getSplits() will take care of the rest AcidUtils.setValidWriteIdList(jobConf, validWriteIds) } else { val finalPaths = new ListBuffer[Path]() @@ -238,62 +237,39 @@ class Hive3RDD[K, V]( } val array = new Array[Partition](inputSplits.size) for (i <- 0 until inputSplits.size) { - array(i) = new Hive3Partition(id, i, inputSplits(i)) + array(i) = new HiveAcidPartition(id, i, inputSplits(i)) } array } catch { case e: InvalidInputException if ignoreMissingFiles => - logWarning(s"${jobConf.get(org.apache.hadoop.mapreduce.lib.input.FileInputFormat.INPUT_DIR)} doesn't exist and no" + + val inputDir = jobConf.get(org.apache.hadoop.mapreduce.lib.input.FileInputFormat.INPUT_DIR) + logWarning(s"$inputDir doesn't exist and no" + s" partitions returned from this path.", e) Array.empty[Partition] } } - override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { - val iter: NextIterator[(K, V)] = new NextIterator[(K, V)] { - - private val split = theSplit.asInstanceOf[Hive3Partition] - logInfo("Input split: " + split.inputSplit) - val scheme = Util.getSplitScheme(split.inputSplit.value) - val jobConf = getJobConf() + override def compute(theSplit: Partition, + context: TaskContext): InterruptibleIterator[(RecordIdentifier, V)] = { + val iter: NextIterator[(RecordIdentifier, V)] = new NextIterator[(RecordIdentifier, V)] { - val inputMetrics = context.taskMetrics().inputMetrics - val existingBytesRead = inputMetrics.bytesRead - val blobStoreInputMetrics: Option[InputMetrics] = None - val existingBlobStoreBytesRead = blobStoreInputMetrics.map(_.bytesRead).sum - - // Sets InputFileBlockHolder for the file block's information - split.inputSplit.value match { - case fs: FileSplit => - InputFileBlockHolder.set(fs.getPath.toString, fs.getStart, fs.getLength) - case _ => - InputFileBlockHolder.unset() - } - - // Find a function that will return the FileSystem bytes read by this thread. Do this before - // creating RecordReader, because RecordReader's constructor might read some bytes - private val getBytesReadCallback: Option[() => Long] = None - - // We get our input bytes from thread-local Hadoop FileSystem statistics. - // If we do a coalesce, however, we are likely to compute multiple partitions in the same - // task and in the same thread, in which case we need to avoid override values written by - // previous partitions (SPARK-13071). 
- private def updateBytesRead(): Unit = { - getBytesReadCallback.foreach { getBytesRead => -// inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) -// blobStoreInputMetrics.foreach(_.setBytesRead(existingBlobStoreBytesRead + getBytesRead())) - } - } + private val split = theSplit.asInstanceOf[HiveAcidPartition] + logDebug("Input split: " + split.inputSplit) + val jobConf: JobConf = getJobConf - private var reader: RecordReader[K, V] = null + private var reader: RecordReader[K, V] = _ private val inputFormat = getInputFormat(jobConf) - Hive3RDD.addLocalConfiguration( + HiveAcidRDD.addLocalConfiguration( new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime), context.stageId, theSplit.index, context.attemptNumber, jobConf) reader = try { - inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + // Underlying code is not MT safe. Synchronize + // while creating record reader + HiveAcidRDD.RECORD_READER_INIT_LOCK.synchronized { + inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + } } catch { case e: FileNotFoundException if ignoreMissingFiles => logWarning(s"Skipped missing file: ${split.inputSplit}", e) @@ -307,19 +283,26 @@ class Hive3RDD[K, V]( null } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener[Unit] { context => - // Update the bytes read before closing is to make sure lingering bytesRead statistics in - // this thread get correctly added. - updateBytesRead() + context.addTaskCompletionListener[Unit] { _ => closeIfNeeded() } private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey() private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue() + private var recordIdentifier: RecordIdentifier = _ + private val acidRecordReader = reader match { + case acidReader: AcidInputFormat.AcidRecordReader[_, _] => + acidReader + case _ => + null + } - override def getNext(): (K, V) = { + override def getNext(): (RecordIdentifier, V) = { try { finished = !reader.next(key, value) + if (!finished && acidRecordReader != null) { + recordIdentifier = acidRecordReader.getRecordIdentifier + } } catch { case e: FileNotFoundException if ignoreMissingFiles => logWarning(s"Skipped missing file: ${split.inputSplit}", e) @@ -330,19 +313,11 @@ class Hive3RDD[K, V]( logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) finished = true } - if (!finished) { -// inputMetrics.incRecordsRead(1) -// blobStoreInputMetrics.foreach(_.incRecordsRead(1)) - } -// if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { -// updateBytesRead() -// } - (key, value) + (recordIdentifier, value) } override def close(): Unit = { if (reader != null) { - InputFileBlockHolder.unset() try { reader.close() } catch { @@ -353,24 +328,10 @@ class Hive3RDD[K, V]( } finally { reader = null } - if (getBytesReadCallback.isDefined) { - updateBytesRead() - } else if (split.inputSplit.value.isInstanceOf[FileSplit] || - split.inputSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. 
- try { -// inputMetrics.incBytesRead(split.inputSplit.value.getLength) -// blobStoreInputMetrics.foreach(_.incBytesRead(split.inputSplit.value.getLength)) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) - } - } } } } - new InterruptibleIterator[(K, V)](context, iter) + new InterruptibleIterator[(RecordIdentifier, V)](context, iter) } @@ -378,16 +339,16 @@ class Hive3RDD[K, V]( /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ @DeveloperApi def mapPartitionsWithInputSplit[U: ClassTag]( - f: (InputSplit, Iterator[(K, V)]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = { - new Hive3PartitionsWithSplitRDD(this, f, preservesPartitioning) + f: (InputSplit, Iterator[(RecordIdentifier, V)]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = { + new HiveAcidPartitionsWithSplitRDD(this, f, preservesPartitioning) } override def getPreferredLocations(split: Partition): Seq[String] = { - val hsplit = split.asInstanceOf[Hive3Partition].inputSplit.value + val hsplit = split.asInstanceOf[HiveAcidPartition].inputSplit.value val locs = hsplit match { case lsplit: InputSplitWithLocationInfo => - Hive3RDD.convertSplitLocationInfo(lsplit.getLocationInfo) + HiveAcidRDD.convertSplitLocationInfo(lsplit.getLocationInfo) case _ => None } locs.getOrElse(hsplit.getLocations.filter(_ != "localhost")) @@ -399,17 +360,25 @@ class Hive3RDD[K, V]( override def persist(storageLevel: StorageLevel): this.type = { if (storageLevel.deserialized) { - logWarning("Caching Hive3RDDs as deserialized objects usually leads to undesired" + + logWarning("Caching HiveAcidRDDs as deserialized objects usually leads to undesired" + " behavior because Hadoop's RecordReader reuses the same Writable object for all records." + " Use a map transformation to make copies of the records.") } super.persist(storageLevel) } - def getConf: Configuration = getJobConf() + def getConf: Configuration = getJobConf } -object Hive3RDD extends Logging { +object HiveAcidRDD extends Logging { + + /* + * Use of utf8Decoder inside OrcRecordUpdater is not MT safe when peforming + * getRecordReader. This leads to illlegal state exeception when called in + * parallel by multiple tasks in single executor (JVM). Synchronize !! + */ + val RECORD_READER_INIT_LOCK = new Object() + /** * Configuration's constructor is not threadsafe (see SPARK-1097 and HADOOP-10456). * Therefore, we synchronize on this lock before calling new JobConf() or new Configuration(). @@ -444,18 +413,18 @@ object Hive3RDD extends Logging { * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to * the given function rather than the index of the partition. 
*/ - class Hive3PartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( + class HiveAcidPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( prev: RDD[T], f: (InputSplit, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false) extends RDD[U](prev) { - override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None + override val partitioner: Option[Partitioner] = if (preservesPartitioning) firstParent[T].partitioner else None override def getPartitions: Array[Partition] = firstParent[T].partitions override def compute(split: Partition, context: TaskContext): Iterator[U] = { - val partition = split.asInstanceOf[Hive3Partition] + val partition = split.asInstanceOf[HiveAcidPartition] val inputSplit = partition.inputSplit.value f(inputSplit, firstParent[T].iterator(split, context)) } @@ -480,3 +449,77 @@ object Hive3RDD extends Logging { }) } } + +/** + * Borrowed from org.apache.spark.util.NextIterator + */ +private abstract class NextIterator[U] extends Iterator[U] { + + private var gotNext = false + private var nextValue: U = _ + private var closed = false + protected var finished = false + + /** + * Method for subclasses to implement to provide the next element. + * + * If no next element is available, the subclass should set `finished` + * to `true` and may return any value (it will be ignored). + * + * This convention is required because `null` may be a valid value, + * and using `Option` seems like it might create unnecessary Some/None + * instances, given some iterators might be called in a tight loop. + * + * @return U, or set 'finished' when done + */ + protected def getNext(): U + + /** + * Method for subclasses to implement when all elements have been successfully + * iterated, and the iteration is done. + * + * Note: `NextIterator` cannot guarantee that `close` will be + * called because it has no control over what happens when an exception + * happens in the user code that is calling hasNext/next. + * + * Ideally you should have another try/catch, as in HadoopRDD, that + * ensures any resources are closed should iteration fail. + */ + protected def close() + + /** + * Calls the subclass-defined close method, but only once. + * + * Usually calling `close` multiple times should be fine, but historically + * there have been issues with some InputFormats throwing exceptions. + */ + def closeIfNeeded() { + if (!closed) { + // Note: it's important that we set closed = true before calling close(), since setting it + // afterwards would permit us to call close() multiple times if close() threw an exception. + closed = true + close() + } + } + + override def hasNext: Boolean = { + if (!finished) { + if (!gotNext) { + nextValue = getNext() + if (finished) { + closeIfNeeded() + } + gotNext = true + } + } + !finished + } + + override def next(): U = { + if (!hasNext) { + throw new NoSuchElementException("End of stream") + } + gotNext = false + nextValue + } +} diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/rdd/AcidLockUnionRDD.scala b/src/main/scala/com/qubole/spark/hiveacid/rdd/HiveAcidUnionRDD.scala similarity index 65% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/rdd/AcidLockUnionRDD.scala rename to src/main/scala/com/qubole/spark/hiveacid/rdd/HiveAcidUnionRDD.scala index 9a639cc..cb23f17 100644 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/rdd/AcidLockUnionRDD.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/rdd/HiveAcidUnionRDD.scala @@ -17,23 +17,27 @@ * limitations under the License. 
*/ -package com.qubole.spark.datasources.hiveacid.rdd +package com.qubole.spark.hiveacid.rdd + +import scala.reflect.ClassTag -import com.qubole.spark.datasources.hiveacid.HiveAcidState import org.apache.spark._ import org.apache.spark.rdd.{RDD, UnionRDD} -import scala.reflect.ClassTag +/** + * A Hive3RDD is created for each of the hive partition of the table. But at the end the buildScan + * is supposed to return only 1 RDD for entire table. So we have to create UnionRDD for it. + * + * This class extends UnionRDD and makes sure that we acquire read lock once for all the + * partitions of the table -class AcidLockUnionRDD[T: ClassTag]( + * @param sc - sparkContext + * @param rddSeq - underlying partition RDDs + */ +private[hiveacid] class HiveAcidUnionRDD[T: ClassTag]( sc: SparkContext, - rddSeq: Seq[RDD[T]], - partitionList: Seq[String], - @transient val acidState: HiveAcidState) extends UnionRDD[T](sc, rddSeq) { - + rddSeq: Seq[RDD[T]]) extends UnionRDD[T](sc, rddSeq) { override def getPartitions: Array[Partition] = { - // Initialize the ACID state here to get the write-ids to read - acidState.beginRead super.getPartitions } } diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/rdd/.gitignore b/src/main/scala/com/qubole/spark/hiveacid/reader/.gitignore similarity index 100% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/rdd/.gitignore rename to src/main/scala/com/qubole/spark/hiveacid/reader/.gitignore diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidDataSource.scala b/src/main/scala/com/qubole/spark/hiveacid/reader/Reader.scala similarity index 63% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidDataSource.scala rename to src/main/scala/com/qubole/spark/hiveacid/reader/Reader.scala index 54c9f98..172c99c 100644 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/HiveAcidDataSource.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/reader/Reader.scala @@ -17,27 +17,17 @@ * limitations under the License. */ -package com.qubole.spark.datasources.hiveacid +package com.qubole.spark.hiveacid.reader -import org.apache.spark.internal.Logging -import org.apache.spark.sql._ -import org.apache.spark.sql.sources._ +import com.qubole.spark.hiveacid.hive.HiveAcidMetadata -class HiveAcidDataSource - extends RelationProvider - with DataSourceRegister - with Logging { +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - new HiveAcidRelation( - sqlContext, - parameters - ) - } - - override def shortName(): String = { - HiveAcidUtils.NAME - } +private[reader] trait Reader { + def makeRDDForTable(hiveAcidMetadata: HiveAcidMetadata): RDD[InternalRow] + def makeRDDForPartitionedTable(hiveAcidMetadata: HiveAcidMetadata, + partitions: Seq[ReaderPartition]): RDD[InternalRow] } + +private[reader] case class ReaderPartition(ptn: Any) diff --git a/src/main/scala/com/qubole/spark/hiveacid/reader/ReaderOptions.scala b/src/main/scala/com/qubole/spark/hiveacid/reader/ReaderOptions.scala new file mode 100644 index 0000000..3f5ce3a --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/reader/ReaderOptions.scala @@ -0,0 +1,36 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.qubole.spark.hiveacid.reader + +import com.qubole.spark.hiveacid.ReadConf +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.sources.Filter + +/** + * Reader options which will be serialized and sent to each executor + */ +private[hiveacid] class ReaderOptions(val hadoopConf: Configuration, + val partitionAttributes: Seq[Attribute], + val requiredAttributes: Seq[Attribute], + val dataFilters: Array[Filter], + val requiredNonPartitionedColumns: Array[String], + val readConf: ReadConf) extends Serializable diff --git a/src/main/scala/com/qubole/spark/hiveacid/reader/TableReader.scala b/src/main/scala/com/qubole/spark/hiveacid/reader/TableReader.scala new file mode 100644 index 0000000..32a2dc2 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/reader/TableReader.scala @@ -0,0 +1,122 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.qubole.spark.hiveacid.reader + +import com.qubole.spark.hiveacid.{HiveAcidOperation, ReadConf} +import com.qubole.spark.hiveacid.transaction._ +import com.qubole.spark.hiveacid.hive.{HiveAcidMetadata, HiveConverter} +import com.qubole.spark.hiveacid.reader.hive.{HiveAcidReader, HiveAcidReaderOptions} + +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.sources.Filter + +/** + * Table reader object + * + * @param sparkSession - Spark session + * @param curTxn - Transaction object to acquire locks. + * @param hiveAcidMetadata - Hive acid table for which read is to be performed. 
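+ *
+ * A minimal sketch of a read through this class (column names, filter and the
+ * surrounding transaction objects are hypothetical; callers normally reach it via
+ * HiveAcidTable.getRdd):
+ * {{{
+ *   val reader = new TableReader(spark, txn, metadata)
+ *   val rdd = reader.getRdd(Array("id", "dt"),
+ *     Array(EqualTo("dt", "2019-01-01")),
+ *     ReadConf.build(spark, Map.empty))
+ * }}}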
+ */ +private[hiveacid] class TableReader(sparkSession: SparkSession, + curTxn: HiveAcidTxn, + hiveAcidMetadata: HiveAcidMetadata) extends Logging { + + def getRdd(requiredColumns: Array[String], + filters: Array[Filter], + readConf: ReadConf): RDD[Row] = { + + + val rowIdColumnSet = hiveAcidMetadata.rowIdSchema.fields.map(_.name).toSet + val requiredColumnsWithoutRowId = requiredColumns.filterNot(rowIdColumnSet.contains) + val partitionColumnNames = hiveAcidMetadata.partitionSchema.fields.map(_.name) + val partitionedColumnSet = partitionColumnNames.toSet + + // Attributes + val requiredNonPartitionedColumns = requiredColumnsWithoutRowId.filter( + x => !partitionedColumnSet.contains(x)) + + val requiredAttributes = requiredColumnsWithoutRowId.map { + x => + val field = hiveAcidMetadata.tableSchema.fields.find(_.name == x).get + PrettyAttribute(field.name, field.dataType) + } + val partitionAttributes = hiveAcidMetadata.partitionSchema.fields.map { x => + PrettyAttribute(x.name, x.dataType) + } + + // Filters + val (partitionFilters, otherFilters) = filters.partition { predicate => + !predicate.references.isEmpty && + predicate.references.toSet.subsetOf(partitionedColumnSet) + } + val dataFilters = otherFilters.filter(_ + .references.intersect(partitionColumnNames).isEmpty + ) + + logDebug(s"total filters : ${filters.length}: " + + s"dataFilters: ${dataFilters.length} " + + s"partitionFilters: ${partitionFilters.length}") + + val hadoopConf = sparkSession.sessionState.newHadoopConf() + + logDebug(s"sarg.pushdown: ${hadoopConf.get("sarg.pushdown")}," + + s"hive.io.file.readcolumn.names: ${hadoopConf.get("hive.io.file.readcolumn.names")}, " + + s"hive.io.file.readcolumn.ids: ${hadoopConf.get("hive.io.file.readcolumn.ids")}") + + val readerOptions = new ReaderOptions(hadoopConf, + partitionAttributes, + requiredAttributes, + dataFilters, + requiredNonPartitionedColumns, + readConf) + + val hiveAcidReaderOptions= HiveAcidReaderOptions.get(hiveAcidMetadata, readConf.includeRowIds) + + val (partitions, partitionList) = HiveAcidReader.getPartitions(hiveAcidMetadata, + readerOptions, + partitionFilters) + + // Acquire lock on all the partition and then create snapshot. Every time getRDD is called + // it creates a new snapshot. + // NB: partitionList is Seq if partition pruning is not enabled + curTxn.acquireLocks(hiveAcidMetadata, HiveAcidOperation.READ, partitionList) + + // Create Snapshot !!! + val curSnapshot = HiveAcidTxn.createSnapshot(curTxn, hiveAcidMetadata) + + val reader = new HiveAcidReader( + sparkSession, + readerOptions, + hiveAcidReaderOptions, + curSnapshot.validWriteIdList) + + val rdd = if (hiveAcidMetadata.isPartitioned) { + reader.makeRDDForPartitionedTable(hiveAcidMetadata, partitions) + } else { + reader.makeRDDForTable(hiveAcidMetadata) + } + + + rdd.asInstanceOf[RDD[Row]] + } +} diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/rdd/TableReader.scala b/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidReader.scala similarity index 55% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/rdd/TableReader.scala rename to src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidReader.scala index b7aecd0..7936f31 100644 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/rdd/TableReader.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidReader.scala @@ -17,44 +17,39 @@ * limitations under the License. 
*/ -package com.qubole.spark.datasources.hiveacid.rdd - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +package com.qubole.spark.hiveacid.reader.hive import java.util import java.util.Properties +import scala.collection.JavaConverters._ + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.Output +import com.qubole.shaded.hadoop.hive.conf.HiveConf.ConfVars +import com.qubole.shaded.hadoop.hive.common.ValidWriteIdList import com.qubole.shaded.hadoop.hive.metastore.api.FieldSchema import com.qubole.shaded.hadoop.hive.metastore.api.hive_metastoreConstants._ import com.qubole.shaded.hadoop.hive.metastore.utils.MetaStoreUtils.{getColumnNamesFromFieldSchema, getColumnTypesFromFieldSchema} import com.qubole.shaded.hadoop.hive.ql.exec.Utilities +import com.qubole.shaded.hadoop.hive.ql.io.{AcidUtils, RecordIdentifier} import com.qubole.shaded.hadoop.hive.ql.metadata.{Partition => HiveJarPartition, Table => HiveTable} import com.qubole.shaded.hadoop.hive.ql.plan.TableDesc import com.qubole.shaded.hadoop.hive.serde2.Deserializer -import com.qubole.shaded.hadoop.hive.serde2.objectinspector.primitive._ import com.qubole.shaded.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} +import com.qubole.shaded.hadoop.hive.serde2.objectinspector.primitive._ +import com.qubole.spark.hiveacid.hive.HiveAcidMetadata +import com.qubole.spark.hiveacid.hive.HiveConverter +import com.qubole.spark.hiveacid.reader.{Reader, ReaderOptions, ReaderPartition} +import com.qubole.spark.hiveacid.rdd._ +import com.qubole.spark.hiveacid.util._ +import org.apache.commons.codec.binary.Base64 import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} -import com.qubole.shaded.hadoop.hive.ql.io.AcidUtils -import com.qubole.spark.datasources.hiveacid.HiveAcidState -import com.qubole.spark.datasources.hiveacid.util.{EmptyRDD, SerializableConfiguration, Util} +import org.apache.hadoop.hive.serde2.ColumnProjectionUtils import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} + import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging @@ -65,65 +60,105 @@ import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.hive.Hive3Inspectors import org.apache.spark.unsafe.types.UTF8String -import scala.collection.JavaConverters._ - /** - * A trait for subclasses that handle table scans. 
- */ -sealed trait TableReader { - def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] - - def makeRDDForPartitionedTable(partitions: Seq[HiveJarPartition]): RDD[InternalRow] -} - + * Helper class for scanning tables stored in Hadoop - e.g., to read + * Hive tables that reside in the data warehouse directory. + * @param sparkSession - spark session + * @param readerOptions - reader options for creating RDD + * @param hiveAcidOptions - hive related reader options for creating RDD + * @param validWriteIds - validWriteIds + */ +private[reader] class HiveAcidReader(sparkSession: SparkSession, + readerOptions: ReaderOptions, + hiveAcidOptions: HiveAcidReaderOptions, + validWriteIds: ValidWriteIdList) -/** - * Helper class for scanning tables stored in Hadoop - e.g., to read Hive tables that reside in the - * data warehouse directory. - */ -class HiveTableReader( - @transient private val attributes: Seq[Attribute], - @transient private val partitionKeys: Seq[Attribute], - @transient private val tableDesc: TableDesc, - @transient private val sparkSession: SparkSession, - @transient private val acidState: HiveAcidState, - hadoopConf: Configuration) - extends TableReader with CastSupport with Logging { +extends CastSupport with Reader with Logging { private val _minSplitsPerRDD = if (sparkSession.sparkContext.isLocal) { - 0 // will splitted based on block by default. + 0 // will be split based on block by default. } else { - math.max(hadoopConf.getInt("mapreduce.job.maps", 1), + math.max(readerOptions.hadoopConf.getInt("mapreduce.job.maps", 1), sparkSession.sparkContext.defaultMinPartitions) } SparkHadoopUtil.get.appendS3AndSparkHadoopConfigurations( - sparkSession.sparkContext.getConf, hadoopConf) + sparkSession.sparkContext.getConf, readerOptions.hadoopConf) private val _broadcastedHadoopConf = - sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + sparkSession.sparkContext.broadcast(new SerializableConfiguration(readerOptions.hadoopConf)) override def conf: SQLConf = sparkSession.sessionState.conf - override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = + /** + * @param hiveAcidMetadata - hive acid metadata for underlying table + * @return - Returns RDD on top of non partitioned hive acid table and list of partitionNames empty list + * for entire table + */ + def makeRDDForTable(hiveAcidMetadata: HiveAcidMetadata): RDD[InternalRow] = { + val hiveTable = hiveAcidMetadata.hTable + + // Push Down Predicate + if (readerOptions.readConf.predicatePushdownEnabled) { + setPushDownFiltersInHadoopConf(readerOptions.hadoopConf, + hiveAcidMetadata, + readerOptions.dataFilters) + } + + // Set Required column. + setRequiredColumnsInHadoopConf(readerOptions.hadoopConf, + hiveAcidMetadata, + readerOptions.requiredNonPartitionedColumns) + + logDebug(s"sarg.pushdown: " + + s"${readerOptions.hadoopConf.get("sarg.pushdown")}," + + s"hive.io.file.readcolumn.names: " + + s"${readerOptions.hadoopConf.get("hive.io.file.readcolumn.names")}, " + + s"hive.io.file.readcolumn.ids: " + + s"${readerOptions.hadoopConf.get("hive.io.file.readcolumn.ids")}") + makeRDDForTable( hiveTable, - Util.classForName(tableDesc.getSerdeClassName, - true).asInstanceOf[Class[Deserializer]] + Util.classForName(hiveAcidOptions.tableDesc.getSerdeClassName, + loadShaded = true).asInstanceOf[Class[Deserializer]] ) + } /** - * Creates a Hadoop RDD to read data from the target table's data directory. Returns a transformed - * RDD that contains deserialized rows. 
- * - * @param hiveTable Hive metadata for the table being scanned. - * @param deserializerClass Class of the SerDe used to deserialize Writables read from Hadoop. - */ - def makeRDDForTable( - hiveTable: HiveTable, - deserializerClass: Class[_ <: Deserializer]): RDD[InternalRow] = { + * @param hiveAcidMetadata - hive acid metadata of underlying table + * @param partitions - partitions for the table + * + * @return - Returns RDD on top of partitioned hive acid table + */ + def makeRDDForPartitionedTable(hiveAcidMetadata: HiveAcidMetadata, + partitions: Seq[ReaderPartition]): RDD[InternalRow] = { + + val partitionToDeserializer = partitions.map(p => p.ptn.asInstanceOf[HiveJarPartition]).map { + part => + val deserializerClassName = part.getTPartition.getSd.getSerdeInfo.getSerializationLib + val deserializer = Util.classForName(deserializerClassName, loadShaded = true) + .asInstanceOf[Class[Deserializer]] + (part, deserializer) + }.toMap + + makeRDDForPartitionedTable(partitionToDeserializer, + filterOpt = None) + } + + /** + * Creates a Hadoop RDD to read data from the target table's data directory. + * Returns a transformed RDD that contains deserialized rows. + * + * @param hiveTable Hive metadata for the table being scanned. + * @param deserializerClass Class of the SerDe used to deserialize Writables read from Hadoop. + */ + private def makeRDDForTable(hiveTable: HiveTable, + deserializerClass: Class[_ <: Deserializer]): RDD[InternalRow] = { assert(!hiveTable.isPartitioned, "makeRDDForTable() cannot be called on a partitioned table, since input formats may " + @@ -131,70 +166,64 @@ class HiveTableReader( // Create local references to member variables, so that the entire `this` object won't be // serialized in the closure below. - val localTableDesc = tableDesc + val localTableDesc = hiveAcidOptions.tableDesc val broadcastedHadoopConf = _broadcastedHadoopConf - val tablePath = hiveTable.getPath - // logDebug("Table input: %s".format(tablePath)) val ifcName = hiveTable.getInputFormatClass.getName - val ifc = Util.classForName(ifcName, true).asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] + val ifc = Util.classForName(ifcName, loadShaded = true) + .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] val hiveRDD = createRddForTable(localTableDesc, hiveTable.getSd.getCols, - hiveTable.getParameters, tablePath.toString, true, ifc) - - val attrsWithIndex = attributes.zipWithIndex - val mutableRow = new SpecificInternalRow(attributes.map(_.dataType)) + hiveTable.getParameters, tablePath.toString, ifc) + + val attrsWithIndex = readerOptions.requiredAttributes.zipWithIndex + val localRowIdSchema: Option[StructType] = hiveAcidOptions.rowIdSchema + val outputRowDataTypes = (localRowIdSchema match { + case Some(schema) => + Seq(schema) + case None => + Seq() + }) ++ readerOptions.requiredAttributes.map(_.dataType) + val mutableRow = new SpecificInternalRow(outputRowDataTypes) + + val mutableRowRecordIds = localRowIdSchema match { + case Some(schema) => Some(new SpecificInternalRow(schema.fields.map(_.dataType))) + case None => None + } val deserializedHiveRDD = hiveRDD.mapPartitions { iter => val hconf = broadcastedHadoopConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, localTableDesc.getProperties) - HiveTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer) + HiveAcidReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, + mutableRowRecordIds, + deserializer) } - - new 
AcidLockUnionRDD[InternalRow](sparkSession.sparkContext, Seq(deserializedHiveRDD), - Seq(), acidState) - } - - override def makeRDDForPartitionedTable(partitions: Seq[HiveJarPartition]): RDD[InternalRow] = { - val partitionToDeserializer = partitions.map{ - part => - val deserializerClassName = part.getTPartition.getSd.getSerdeInfo.getSerializationLib - val deserializer = Util.classForName(deserializerClassName, true) - .asInstanceOf[Class[Deserializer]] - (part, deserializer) - }.toMap - makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None) + new HiveAcidUnionRDD[InternalRow](sparkSession.sparkContext, Seq(deserializedHiveRDD)) } /** - * Create a Hive3RDD for every partition key specified in the query. - * - * @param partitionToDeserializer Mapping from a Hive Partition metadata object to the SerDe - * class to use to deserialize input Writables from the corresponding partition. - * @param filterOpt If defined, then the filter is used to reject files contained in the data - * subdirectory of each partition being read. If None, then all files are accepted. - */ - def makeRDDForPartitionedTable( - partitionToDeserializer: Map[HiveJarPartition, Class[_ <: Deserializer]], - filterOpt: Option[PathFilter]): RDD[InternalRow] = { - - val partitionStrings = partitionToDeserializer.map { case (partition, _) => - // val partKeysFieldSchema = partition.getTable.getPartitionKeys.asScala - // partKeysFieldSchema.map(_.getName).mkString("/") - partition.getName - }.toSeq + * Create a HiveAcidRDD for every partition key specified in the query. + * + * @param partitionToDeserializer Mapping from a Hive Partition metadata object to the SerDe + * class to use to deserialize input Writables from the corresponding partition. + * @param filterOpt If defined, then the filter is used to reject files contained in the data + * subdirectory of each partition being read. If None, then all files are accepted. + */ + private def makeRDDForPartitionedTable( + partitionToDeserializer: Map[HiveJarPartition, Class[_ <: Deserializer]], + filterOpt: Option[PathFilter]): RDD[InternalRow] = { val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) => val partProps = partition.getMetadataFromPartitionSchema val partPath = partition.getDataLocation val inputPathStr = applyFilterIfNeeded(partPath, filterOpt) val ifcString = partition.getTPartition.getSd.getInputFormat - val ifc = Util.classForName(ifcString, true) + val ifc = Util.classForName(ifcString, loadShaded = true) .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] // Get partition field info val partSpec = partition.getSpec - val partCols = partition.getTable.getPartitionKeys.asScala.map(_.getName).toSeq + val partCols = partition.getTable.getPartitionKeys.asScala.map(_.getName) // 'partValues[i]' contains the value for the partitioning column at 'partCols[i]'. val partValues = if (partSpec == null) { Array.fill(partCols.size)(new String) @@ -204,31 +233,51 @@ class HiveTableReader( val broadcastedHiveConf = _broadcastedHadoopConf val localDeserializer = partDeserializer - val mutableRow = new SpecificInternalRow(attributes.map(_.dataType)) - // Splits all attributes into two groups, partition key attributes and those that are not. - // Attached indices indicate the position of each attribute in the output schema. 
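+        // When row ids are included, the output row is laid out as
+        // [rowId struct, required data columns...]; partition values and data
+        // fields are therefore written at (offset + ordinal), with offset = 1.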
+ val localRowIdSchema: Option[StructType] = hiveAcidOptions.rowIdSchema + val outputRowDataTypes = (localRowIdSchema match { + case Some(schema) => + Seq(schema) + case None => + Seq() + }) ++ readerOptions.requiredAttributes.map(_.dataType) + val mutableRow = new SpecificInternalRow(outputRowDataTypes) + val mutableRowRecordIds = localRowIdSchema match { + case Some(schema) => Some(new SpecificInternalRow(schema.fields.map(_.dataType))) + case None => None + } + + // Splits all attributes into two groups, partition key attributes and those + // that are not. Attached indices indicate the position of each attribute in + // the output schema. val (partitionKeyAttrs, nonPartitionKeyAttrs) = - attributes.zipWithIndex.partition { case (attr, _) => - partitionKeys.contains(attr) + readerOptions.requiredAttributes.zipWithIndex.partition { case (attr, _) => + readerOptions.partitionAttributes.contains(attr) } def fillPartitionKeys(rawPartValues: Array[String], row: InternalRow): Unit = { + val offset = localRowIdSchema match { + case Some(_) => + 1 + case None => + 0 + } partitionKeyAttrs.foreach { case (attr, ordinal) => - val partOrdinal = partitionKeys.indexOf(attr) - row(ordinal) = cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) + val partOrdinal = readerOptions.partitionAttributes.indexOf(attr) + row(offset + ordinal) = cast( + Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) } } // Fill all partition keys to the given MutableRow object fillPartitionKeys(partValues, mutableRow) - val tableProperties = tableDesc.getProperties + val tableProperties = hiveAcidOptions.tableDesc.getProperties // Create local references so that the outer object isn't serialized. - val localTableDesc = tableDesc - createRddForTable(localTableDesc, partition.getCols, partition.getTable.getParameters, inputPathStr, false, - ifc).mapPartitions { iter => + val localTableDesc = hiveAcidOptions.tableDesc + createRddForTable(localTableDesc, partition.getCols, + partition.getTable.getParameters, inputPathStr, ifc).mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = localDeserializer.newInstance() // SPARK-13709: For SerDes like AvroSerDe, some essential information (e.g. 
Avro schema @@ -243,21 +292,23 @@ class HiveTableReader( deserializer.initialize(hconf, props) // get the table deserializer val tableSerDeClassName = localTableDesc.getSerdeClassName - val tableSerDe = Util.classForName(tableSerDeClassName, true).newInstance().asInstanceOf[Deserializer] + val tableSerDe = Util.classForName(tableSerDeClassName, + loadShaded = true).newInstance().asInstanceOf[Deserializer] tableSerDe.initialize(hconf, localTableDesc.getProperties) // fill the non partition key attributes - HiveTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, - mutableRow, tableSerDe) + HiveAcidReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, + mutableRow, + mutableRowRecordIds, + tableSerDe) } }.toSeq // Even if we don't use any partitions, we still need an empty RDD - if (hivePartitionRDDs.size == 0) { + if (hivePartitionRDDs.isEmpty) { new EmptyRDD[InternalRow](sparkSession.sparkContext) } else { - new AcidLockUnionRDD[InternalRow](hivePartitionRDDs(0).context, hivePartitionRDDs, - partitionStrings, acidState) + new HiveAcidUnionRDD[InternalRow](hivePartitionRDDs.head.context, hivePartitionRDDs) } } @@ -268,7 +319,7 @@ class HiveTableReader( private def applyFilterIfNeeded(path: Path, filterOpt: Option[PathFilter]): String = { filterOpt match { case Some(filter) => - val fs = path.getFileSystem(hadoopConf) + val fs = path.getFileSystem(readerOptions.hadoopConf) val filteredFiles = fs.listStatus(path, filter).map(_.getPath.toString) filteredFiles.mkString(",") case None => path.toString @@ -276,24 +327,25 @@ class HiveTableReader( } /** - * Creates a Hive3RDD based on the broadcasted HiveConf and other job properties that will be + * Creates a HiveAcidRDD based on the broadcasted HiveConf and other job properties that will be * applied locally on each slave. */ private def createRddForTable(tableDesc: TableDesc, cols: util.List[FieldSchema], tableParameters: util.Map[String, String], path: String, - acquireLocks: Boolean, inputFormatClass: Class[InputFormat[Writable, Writable]] - ): RDD[Writable] = { + ): RDD[(RecordIdentifier, Writable)] = { val colNames = getColumnNamesFromFieldSchema(cols) val colTypes = getColumnTypesFromFieldSchema(cols) - val initializeJobConfFunc = HiveTableReader.initializeLocalJobConfFunc(path, tableDesc, tableParameters, + val initializeJobConfFunc = HiveAcidReader.initializeLocalJobConfFunc( + path, tableDesc, tableParameters, colNames, colTypes) _ - val rdd = new Hive3RDD( + val rdd = new HiveAcidRDD( sparkSession.sparkContext, - acidState, + validWriteIds, + hiveAcidOptions.isFullAcidTable, _broadcastedHadoopConf.asInstanceOf[Broadcast[SerializableConfiguration]], Some(initializeJobConfFunc), inputFormatClass, @@ -301,18 +353,65 @@ class HiveTableReader( classOf[Writable], _minSplitsPerRDD) - // Only take the value (skip the key) because Hive works only with values. 
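+    // Keep the RecordIdentifier key alongside each Writable: fillObject() uses it
+    // to populate the rowId columns (writeId, bucketProperty, rowId) when row ids
+    // are included in the read.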
- rdd.map(_._2) + rdd + } + + private def setRequiredColumnsInHadoopConf(conf: Configuration, + acidTableMetadata: HiveAcidMetadata, + requiredColumns: Seq[String]): Unit = { + val dataCols: Seq[String] = acidTableMetadata.dataSchema.fields.map(_.name) + val requiredColumnIndexes = requiredColumns.map(a => dataCols.indexOf(a): Integer) + val (sortedIDs, sortedNames) = requiredColumnIndexes.zip(requiredColumns).sorted.unzip + conf.set(ColumnProjectionUtils.READ_ALL_COLUMNS, "false") + conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, sortedNames.mkString(",")) + conf.set(ColumnProjectionUtils.READ_COLUMN_IDS_CONF_STR, sortedIDs.mkString(",")) + } + + private def setPushDownFiltersInHadoopConf(conf: Configuration, + acidTableMetadata: HiveAcidMetadata, + dataFilters: Array[Filter]): Unit = { + HiveAcidSearchArgument.build(acidTableMetadata.dataSchema, dataFilters).foreach { f => + def toKryo(obj: com.qubole.shaded.hadoop.hive.ql.io.sarg.SearchArgument): String = { + val out = new Output(4 * 1024, 10 * 1024 * 1024) + new Kryo().writeObject(out, obj) + out.close() + Base64.encodeBase64String(out.toBytes) + } + + logDebug(s"searchArgument: $f") + conf.set("sarg.pushdown", toKryo(f)) + conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) + } } } -object HiveTableUtil { +private[reader] object HiveAcidReader extends Hive3Inspectors with Logging { + + def getPartitions(hiveAcidMetadata: HiveAcidMetadata, + readerOptions: ReaderOptions, + partitionFilters: Seq[Filter]): (Seq[ReaderPartition], Seq[String]) = { + + val partitions = if (hiveAcidMetadata.isPartitioned) { + val partitionPruiningFilters = if (readerOptions.readConf.metastorePartitionPruningEnabled) { + Option(HiveConverter.compileFilters(partitionFilters)) + } else { + None + } + hiveAcidMetadata.getRawPartitions(partitionPruiningFilters) + } else { + Seq() + } + + val partitionNames = partitions.map(_.getName()) + + (partitions.map(p => ReaderPartition(p)), partitionNames) + } // copied from PlanUtils.configureJobPropertiesForStorageHandler(tableDesc) // that calls Hive.get() which tries to access metastore, but it's not valid in runtime // it would be fixed in next version of hive but till then, we should use this instead - def configureJobPropertiesForStorageHandler( - tableDesc: TableDesc, conf: Configuration, input: Boolean) { + private def configureJobPropertiesForStorageHandler(tableDesc: TableDesc, + conf: Configuration, input: Boolean) { val property = tableDesc.getProperties.getProperty(META_TABLE_STORAGE) val storageHandler = com.qubole.shaded.hadoop.hive.ql.metadata.HiveUtils.getStorageHandler(conf, property) @@ -328,12 +427,10 @@ object HiveTableUtil { } } } -} -object HiveTableReader extends Hive3Inspectors with Logging { /** * Curried. After given an argument for 'path', the resulting JobConf => Unit closure is used to - * instantiate a Hive3RDD. + * instantiate a HiveAcidRDD. 
*/ def initializeLocalJobConfFunc(path: String, tableDesc: TableDesc, tableParameters: util.Map[String, String], @@ -341,7 +438,7 @@ object HiveTableReader extends Hive3Inspectors with Logging { schemaColTypes: String)(jobConf: JobConf) { FileInputFormat.setInputPaths(jobConf, Seq[Path](new Path(path)): _*) if (tableDesc != null) { - HiveTableUtil.configureJobPropertiesForStorageHandler(tableDesc, jobConf, true) + configureJobPropertiesForStorageHandler(tableDesc, jobConf, input = true) Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) } val bufferSize = System.getProperty("spark.buffer.size", "65536") @@ -360,7 +457,7 @@ object HiveTableReader extends Hive3Inspectors with Logging { } /** - * Transform all given raw `Writable`s into `Row`s. + * Transform all given raw `(RowIdentifier, Writable)`s into `InternalRow`s. * * @param iterator Iterator of all `Writable`s to be transformed * @param rawDeser The `Deserializer` associated with the input `Writable` @@ -371,10 +468,11 @@ object HiveTableReader extends Hive3Inspectors with Logging { * @return An `Iterator[Row]` transformed from `iterator` */ def fillObject( - iterator: Iterator[Writable], + iterator: Iterator[(RecordIdentifier, Writable)], rawDeser: Deserializer, nonPartitionKeyAttrs: Seq[(Attribute, Int)], mutableRow: InternalRow, + mutableRowRecordId: Option[InternalRow], tableDeser: Deserializer): Iterator[InternalRow] = { val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { @@ -441,15 +539,27 @@ object HiveTableReader extends Hive3Inspectors with Logging { // Map each tuple to a row object iterator.map { value => - val raw = converter.convert(rawDeser.deserialize(value)) + val dataStartIndex = mutableRowRecordId match { + case Some(record) => + val recordIdentifier = value._1 + record.setLong(0, recordIdentifier.getWriteId) + record.setInt(1, recordIdentifier.getBucketProperty) + record.setLong(2, recordIdentifier.getRowId) + mutableRow.update(0, record) + 1 + case None => + 0 + } + + val raw = converter.convert(rawDeser.deserialize(value._2)) var i = 0 val length = fieldRefs.length while (i < length) { val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) if (fieldValue == null) { - mutableRow.setNullAt(fieldOrdinals(i)) + mutableRow.setNullAt(fieldOrdinals(i) + dataStartIndex) } else { - unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i) + dataStartIndex) } i += 1 } diff --git a/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidReaderOptions.scala b/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidReaderOptions.scala new file mode 100644 index 0000000..c788bb5 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidReaderOptions.scala @@ -0,0 +1,39 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.qubole.spark.hiveacid.reader.hive + +import com.qubole.shaded.hadoop.hive.ql.plan.TableDesc +import com.qubole.spark.hiveacid.hive.HiveAcidMetadata +import org.apache.spark.sql.types.StructType + +private[reader] class HiveAcidReaderOptions(val tableDesc: TableDesc, + val rowIdSchema: Option[StructType], + val isFullAcidTable: Boolean) + +private[reader] object HiveAcidReaderOptions { + def get(hiveAcidMetadata: HiveAcidMetadata, includeRowIds: Boolean): HiveAcidReaderOptions = { + val rowIdSchema = if (includeRowIds) { + Option(hiveAcidMetadata.rowIdSchema) + } else { + None + } + new HiveAcidReaderOptions(hiveAcidMetadata.tableDesc, rowIdSchema, hiveAcidMetadata.isFullAcidTable) + } +} diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/orc/OrcFilters.scala b/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidSearchArgument.scala similarity index 80% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/orc/OrcFilters.scala rename to src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidSearchArgument.scala index 5b8b80e..0a1ba9b 100644 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/orc/OrcFilters.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidSearchArgument.scala @@ -17,47 +17,50 @@ * limitations under the License. */ -package com.qubole.spark.datasources.hiveacid.orc +package com.qubole.spark.hiveacid.reader.hive import com.qubole.shaded.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} import com.qubole.shaded.hadoop.hive.ql.io.sarg.SearchArgument.Builder import com.qubole.shaded.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import com.qubole.shaded.hadoop.hive.serde2.io.HiveDecimalWritable + import org.apache.spark.sql.sources.{And, Filter} import org.apache.spark.sql.types._ /** - * Helper object for building ORC `SearchArgument`s, which are used for ORC predicate push-down. - * - * Due to limitation of ORC `SearchArgument` builder, we had to end up with a pretty weird double- - * checking pattern when converting `And`/`Or`/`Not` filters. - * - * An ORC `SearchArgument` must be built in one pass using a single builder. For example, you can't - * build `a = 1` and `b = 2` first, and then combine them into `a = 1 AND b = 2`. This is quite - * different from the cases in Spark SQL or Parquet, where complex filters can be easily built using - * existing simpler ones. - * - * The annoying part is that, `SearchArgument` builder methods like `startAnd()`, `startOr()`, and - * `startNot()` mutate internal state of the builder instance. This forces us to translate all - * convertible filters with a single builder instance. However, before actually converting a filter, - * we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible filter is - * found, we may already end up with a builder whose internal state is inconsistent. - * - * For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and then - * try to convert its children. 
Say we convert `left` child successfully, but find that `right` - * child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is inconsistent - * now. - * - * The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their - * children with brand new builders, and only do the actual conversion with the right builder - * instance when the children are proven to be convertible. - * - * P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. Usage of - * builder methods mentioned above can only be found in test code, where all tested filters are - * known to be convertible. - */ -object OrcFilters { - def buildTree(filters: Seq[Filter]): Option[Filter] = { + * Copied from org.apache.spark.sql.execution.datasources.orc.OrcFilters + * + * Helper object for building ORC `SearchArgument`s, which are used for ORC predicate push-down. + * + * Due to limitation of ORC `SearchArgument` builder, we had to end up with a pretty weird double- + * checking pattern when converting `And`/`Or`/`Not` filters. + * + * An ORC `SearchArgument` must be built in one pass using a single builder. For example, you can't + * build `a = 1` and `b = 2` first, and then combine them into `a = 1 AND b = 2`. This is quite + * different from the cases in Spark SQL or Parquet, where complex filters can be easily built using + * existing simpler ones. + * + * The annoying part is that, `SearchArgument` builder methods like `startAnd()`, `startOr()`, and + * `startNot()` mutate internal state of the builder instance. This forces us to translate all + * convertible filters with a single builder instance. However, before actually converting a filter, + * we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible filter is + * found, we may already end up with a builder whose internal state is inconsistent. + * + * For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and then + * try to convert its children. Say we convert `left` child successfully, but find that `right` + * child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is inconsistent + * now. + * + * The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their + * children with brand new builders, and only do the actual conversion with the right builder + * instance when the children are proven to be convertible. + * + * P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. Usage of + * builder methods mentioned above can only be found in test code, where all tested filters are + * known to be convertible. + */ +private[hive] object HiveAcidSearchArgument { + private def buildTree(filters: Seq[Filter]): Option[Filter] = { filters match { case Seq() => None case Seq(filter) => Some(filter) @@ -70,7 +73,7 @@ object OrcFilters { // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters // in order to distinguish predicate pushdown for nested columns. - private def quoteAttributeNameIfNeeded(name: String) : String = { + private def quoteAttributeNameIfNeeded(name: String) : String = { if (!name.contains("`") && name.contains(".")) { s"`$name`" } else { @@ -78,27 +81,6 @@ object OrcFilters { } } - /** - * Create ORC filter as a SearchArgument instance. 
- */ - def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap - - // First, tries to convert each filter individually to see whether it's convertible, and then - // collect all convertible ones to build the final `SearchArgument`. - val convertibleFilters = for { - filter <- filters - _ <- buildSearchArgument(dataTypeMap, filter, newBuilder) - } yield filter - - for { - // Combines all convertible filters using `And` to produce a single conjunction - conjunction <- buildTree(convertibleFilters) - // Then tries to build a single ORC `SearchArgument` for the conjunction predicate - builder <- buildSearchArgument(dataTypeMap, conjunction, newBuilder) - } yield builder.build() - } - /** * Return true if this is a searchable type in ORC. * Both CharType and VarcharType are cleaned at AstBuilder. @@ -238,4 +220,26 @@ object OrcFilters { case _ => None } } + + /** + * Create filters as a SearchArgument instance. + */ + def build(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { + val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + + // First, tries to convert each filter individually to see whether it's convertible, and then + // collect all convertible ones to build the final `SearchArgument`. + val convertibleFilters = for { + filter <- filters + _ <- buildSearchArgument(dataTypeMap, filter, newBuilder) + } yield filter + + for { + // Combines all convertible filters using `And` to produce a single conjunction + conjunction <- buildTree(convertibleFilters) + // Then tries to build a single ORC `SearchArgument` for the conjunction predicate + builder <- buildSearchArgument(dataTypeMap, conjunction, newBuilder) + } yield builder.build() + } + } diff --git a/src/main/scala/com/qubole/spark/hiveacid/transaction/HiveAcidTxn.scala b/src/main/scala/com/qubole/spark/hiveacid/transaction/HiveAcidTxn.scala new file mode 100644 index 0000000..76a9521 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/transaction/HiveAcidTxn.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.qubole.spark.hiveacid.transaction + +import java.util.concurrent.atomic.AtomicBoolean + +import com.qubole.shaded.hadoop.hive.common.{ValidTxnList, ValidWriteIdList} +import com.qubole.spark.hiveacid.{HiveAcidErrors, HiveAcidOperation} +import com.qubole.spark.hiveacid.hive.HiveAcidMetadata + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession + +/** + * Hive Acid Transaction object. 
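+ * A transaction is begin()-ed against the metastore, then used to acquire locks
+ * and to create a snapshot, and finally end()-ed with either commit or abort.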
+ * @param sparkSession: Spark Session + */ +class HiveAcidTxn(sparkSession: SparkSession) extends Logging { + + HiveAcidTxn.setUpTxnManager(sparkSession) + + // txn ID + protected var id: Long = -1 + protected var validTxnList: ValidTxnList = _ + private [hiveacid] val isClosed: AtomicBoolean = new AtomicBoolean(true) + + private def setTxn(id: Long, txns:ValidTxnList): Unit = { + this.id = id + this.validTxnList = txns + isClosed.set(false) + } + + private def unsetTxn(): Unit = { + this.id = -1 + this.validTxnList = null + isClosed.set(true) + } + + override def toString: String = s"""{"id":"$id","validTxns":"$validTxnList"}""" + + /** + * Public API to being transaction. + */ + def begin(): Unit = synchronized { + if (!isClosed.get) { + throw HiveAcidErrors.txnAlreadyOpen(id) + } + val newId = HiveAcidTxn.txnManager.beginTxn(this) + val txnList = HiveAcidTxn.txnManager.getValidTxns(Some(newId)) + setTxn(newId, txnList) + // Set it for thread for all future references. + HiveAcidTxn.threadLocal.set(this) + logDebug(s"Begin transaction $this") + } + + /** + * Public API to end transaction + * @param abort true if transaction is aborted + */ + def end(abort: Boolean = false): Unit = synchronized { + if (isClosed.get) { + throw HiveAcidErrors.txnAlreadyClosed(id) + } + + logDebug(s"End transaction $this abort = $abort") + // NB: Unset it for thread proactively invariant of + // underlying call fails or succeeds. + HiveAcidTxn.threadLocal.set(null) + HiveAcidTxn.txnManager.endTxn(id, abort) + unsetTxn() + } + + private[hiveacid] def acquireLocks(hiveAcidMetadata: HiveAcidMetadata, + operationType: HiveAcidOperation.OperationType, + partitionNames: Seq[String]): Unit = { + if (isClosed.get()) { + logError(s"Transaction already closed $this") + throw HiveAcidErrors.txnAlreadyClosed(id) + } + HiveAcidTxn.txnManager.acquireLocks(id, hiveAcidMetadata.dbName, + hiveAcidMetadata.tableName, operationType, partitionNames) + } + // Public Interface + def txnId: Long = id +} + +object HiveAcidTxn extends Logging { + + val threadLocal = new ThreadLocal[HiveAcidTxn] + + // Helper function to create snapshot. + private[hiveacid] def createSnapshot(txn: HiveAcidTxn, hiveAcidMetadata: HiveAcidMetadata): HiveAcidTableSnapshot = { + val currentWriteId = txnManager.getCurrentWriteId(txn.txnId, + hiveAcidMetadata.dbName, hiveAcidMetadata.tableName) + val validWriteIdList = if (txn.txnId == - 1) { + throw HiveAcidErrors.tableWriteIdRequestedBeforeTxnStart (hiveAcidMetadata.fullyQualifiedName) + } else { + txnManager.getValidWriteIds(txn.txnId, txn.validTxnList ,hiveAcidMetadata.fullyQualifiedName) + } + HiveAcidTableSnapshot(validWriteIdList, currentWriteId) + } + + // Txn manager is connection to HMS. Use single instance of it + var txnManager: HiveAcidTxnManager = _ + private def setUpTxnManager(sparkSession: SparkSession): Unit = synchronized { + if (txnManager == null) { + txnManager = new HiveAcidTxnManager(sparkSession) + } + } + + /** + * Creates read or write transaction based on user request. + * + * @param sparkSession Create a new hive Acid transaction + * @return + */ + def createTransaction(sparkSession: SparkSession): HiveAcidTxn = { + setUpTxnManager(sparkSession) + new HiveAcidTxn(sparkSession) + } + + /** + * Given a transaction id return the HiveAcidTxn object. Raise exception if not found. 
+ * @return + */ + def currentTxn(): HiveAcidTxn = { + threadLocal.get() + } +} + +private[hiveacid] case class HiveAcidTableSnapshot(validWriteIdList: ValidWriteIdList, currentWriteId: Long) diff --git a/src/main/scala/com/qubole/spark/hiveacid/transaction/HiveAcidTxnManager.scala b/src/main/scala/com/qubole/spark/hiveacid/transaction/HiveAcidTxnManager.scala new file mode 100644 index 0000000..022795e --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/transaction/HiveAcidTxnManager.scala @@ -0,0 +1,362 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.qubole.spark.hiveacid.transaction + +import java.util.concurrent.{Executors, ScheduledExecutorService, ThreadFactory, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean + +import com.qubole.shaded.hadoop.hive.common.{ValidTxnList, ValidTxnWriteIdList, ValidWriteIdList} +import com.qubole.shaded.hadoop.hive.conf.HiveConf +import com.qubole.shaded.hadoop.hive.metastore.api.{DataOperationType, LockRequest, LockResponse, LockState} +import com.qubole.shaded.hadoop.hive.metastore.conf.MetastoreConf +import com.qubole.shaded.hadoop.hive.metastore.txn.TxnUtils +import com.qubole.shaded.hadoop.hive.metastore.{HiveMetaStoreClient, LockComponentBuilder, LockRequestBuilder} +import com.qubole.spark.hiveacid.datasource.HiveAcidDataSource +import com.qubole.spark.hiveacid.hive.HiveConverter +import com.qubole.spark.hiveacid.{HiveAcidErrors, HiveAcidOperation} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.SqlUtils +import com.qubole.shaded.thrift.TException +import scala.collection.JavaConversions._ +import scala.language.implicitConversions + +/** + * Txn Manager for hive acid tables. 
+ * This takes care of creating txns, acquiring locks and sending heartbeats + * @param sparkSession - Spark session + */ +private[hiveacid] class HiveAcidTxnManager(sparkSession: SparkSession) extends Logging { + + private val hiveConf = HiveConverter.getHiveConf(sparkSession.sparkContext) + + private val heartbeatInterval = MetastoreConf.getTimeVar(hiveConf, + MetastoreConf.ConfVars.TXN_TIMEOUT, TimeUnit.MILLISECONDS) / 2 + + private lazy val client: HiveMetaStoreClient = new HiveMetaStoreClient( + hiveConf, null, false) + + private lazy val heartBeaterClient: HiveMetaStoreClient = + new HiveMetaStoreClient(hiveConf, null, false) + + // FIXME: Use thread pool so that we don't create multiple threads + private val heartBeater: ScheduledExecutorService = + Executors.newSingleThreadScheduledExecutor(new ThreadFactory() { + def newThread(r: Runnable) = new HeartBeaterThread(r, "AcidDataSourceHeartBeater") + }) + heartBeater.scheduleAtFixedRate( + new HeartbeatRunnable(), + 0, + heartbeatInterval, + TimeUnit.MILLISECONDS) + + private val user: String = sparkSession.sparkContext.sparkUser + + private val shutdownInitiated: AtomicBoolean = new AtomicBoolean(true) + + /** + * Register transactions with Hive Metastore and tracks it under HiveAcidTxnManager.activeTxns + * @param txn transaction which needs to begin. + * @return Update transaction object + */ + def beginTxn(txn: HiveAcidTxn): Long = synchronized { + // 1. Open transaction + val txnId = client.openTxn(HiveAcidDataSource.NAME) + if (HiveAcidTxnManager.activeTxns.contains(txnId)) { + throw HiveAcidErrors.repeatedTxnId(txnId, HiveAcidTxnManager.activeTxns.keySet.toSeq) + } + HiveAcidTxnManager.activeTxns.put(txnId, txn) + logDebug(s"Adding txnId: $txnId to tracker") + txnId + } + + /** + * Wrapper over call to hive metastore to end transaction either with commit or abort. + * @param txnId id of transaction to be end + * @param abort true if transaction is to be aborted. + */ + def endTxn(txnId: Long, abort: Boolean = false): Unit = synchronized { + try { + // NB: Remove it from tracking before making HMS call + // which can potentially fail. + HiveAcidTxnManager.activeTxns.remove(txnId) + logDebug(s"Removing txnId: $txnId from tracker") + if (abort) { + client.abortTxns(scala.collection.JavaConversions.seqAsJavaList(Seq(txnId))) + } else { + client.commitTxn(txnId) + } + } catch { + case e: Exception => + logError(s"Failure to end txn: $txnId, presumed abort", e) + } + } + + /** + * Destroy the transaction object. Closes all the pooled connection, + * stops heartbeat and aborts all running transactions. + */ + def close(): Unit = synchronized { + + if (!shutdownInitiated.compareAndSet(false, true)) { + return + } + + // Stop the heartbeat executor + if (heartBeater != null) { + heartBeater.shutdown() + } + + // NB: caller of close itself is from heartbeater thread + // such await would be self deadlock. + // heartBeater.awaitTermination(10, TimeUnit.SECONDS) + + // abort all active transactions + HiveAcidTxnManager.activeTxns.foreach { + case (_, txn) => txn.end(true) + } + HiveAcidTxnManager.activeTxns.clear() + + // close all clients + if (client != null) { + client.close() + } + if (heartBeaterClient != null) { + heartBeaterClient.close() + } + } + + /** + * Returns current write id. + * @param txnId transaction id for which current write id is requested. 
+ * @param dbName: Database name + * @param tableName: Table name + * @return + */ + def getCurrentWriteId(txnId: Long, dbName: String, tableName: String): Long = synchronized { + client.allocateTableWriteId(txnId, dbName, tableName) + } + + /** + * Return list of valid txn list. + * @param txnIdOpt txn id, current if None is passed. + * @return + */ + def getValidTxns(txnIdOpt: Option[Long]): ValidTxnList = synchronized { + txnIdOpt match { + case Some(id) => client.getValidTxns(id) + case None => client.getValidTxns() + } + } + + /** + * Return list of all valid write ids for the table. + * @param fullyQualifiedTableName name of the table + * @param validTxnList valid txn list snapshot. + * @return List of valid write ids + */ + def getValidWriteIds(validTxnList: ValidTxnList, + fullyQualifiedTableName: String): ValidWriteIdList = synchronized { + getValidWriteIds(None, validTxnList, fullyQualifiedTableName) + } + + /** + * Return list of all valid write ids for the table for given transactions + * @param txnId transaction id + * @param validTxnList valid txn list snapshot. + * @param fullyQualifiedTableName table name + * @return List of valid write ids + */ + def getValidWriteIds(txnId: Long, + validTxnList: ValidTxnList, + fullyQualifiedTableName: String): ValidWriteIdList = synchronized { + getValidWriteIds(Option(txnId), validTxnList, fullyQualifiedTableName) + } + + private def getValidWriteIds(txnIdOpt: Option[Long], + validTxnList: ValidTxnList, + fullyQualifiedTableName: String): ValidWriteIdList = synchronized { + val txnId = txnIdOpt match { + case Some(id) => id + case None => -1L + } + val tableValidWriteIds = client.getValidWriteIds(Seq(fullyQualifiedTableName), + validTxnList.writeToString()) + val txnWriteIds: ValidTxnWriteIdList = TxnUtils.createValidTxnWriteIdList(txnId, + tableValidWriteIds) + txnWriteIds.getTableValidWriteIdList(fullyQualifiedTableName) + } + + /** + * API to acquire locks on partitions + * @param txnId transaction id + * @param dbName: Database name + * @param tableName: Table name + * @param operationType lock type + * @param partitionNames partition names + */ + def acquireLocks(txnId: Long, + dbName: String, + tableName: String, + operationType: HiveAcidOperation.OperationType, + partitionNames: Seq[String]): Unit = synchronized { + + // Consider following sequence of event + // T1: R(x) + // T2: R(x) + // T2: W(x) + // T2: Commit + // T1: W(x) + // Because read happens with MVCC it is possible that some other transaction + // may have come and performed write. To protect against the lost write due + // to above sequence hive maintains write-set and abort conflict transaction + // optimistically at the commit time. 
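+      // Lock modes requested below (see addLockType): READ and INSERT INTO take a
+      // shared lock, UPDATE and DELETE take a semi-shared lock, and INSERT OVERWRITE
+      // takes an exclusive lock, requested on the table or on each affected partition.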
+ def addLockType(lcb: LockComponentBuilder): LockComponentBuilder = { + operationType match { + case HiveAcidOperation.INSERT_OVERWRITE => + lcb.setExclusive().setOperationType(DataOperationType.UPDATE) + case HiveAcidOperation.INSERT_INTO => + lcb.setShared().setOperationType(DataOperationType.INSERT) + case HiveAcidOperation.READ => + lcb.setShared().setOperationType(DataOperationType.SELECT) + case HiveAcidOperation.UPDATE => + lcb.setSemiShared().setOperationType(DataOperationType.UPDATE) + case HiveAcidOperation.DELETE => + lcb.setSemiShared().setOperationType(DataOperationType.DELETE) + case _ => + throw HiveAcidErrors.invalidOperationType(operationType.toString) + } + } + + def createLockRequest() = { + val requestBuilder = new LockRequestBuilder(HiveAcidDataSource.NAME) + requestBuilder.setUser(user) + requestBuilder.setTransactionId(txnId) + if (partitionNames.isEmpty) { + val lockCompBuilder = new LockComponentBuilder() + .setDbName(dbName) + .setTableName(tableName) + + requestBuilder.addLockComponent(addLockType(lockCompBuilder).build) + } else { + partitionNames.foreach(partName => { + val lockCompBuilder = new LockComponentBuilder() + .setPartitionName(partName) + .setDbName(dbName) + .setTableName(tableName) + requestBuilder.addLockComponent(addLockType(lockCompBuilder).build) + }) + } + requestBuilder.build + } + + def lock(lockReq: LockRequest): Unit = { + var nextSleep = 50L + + // FIXME: This is crazy long wait for locks. Sleep starts from 50ms and + // exponentially in power of 2 backs off to MAX_SLEEP the maximum sleep + // for unsuccessful lock acquisition time is maxNumWaits * MAX_SLEEP, + // which defaults to 60s * 100 that is 6000s that is 2hours. + val defaultMaxSleep = hiveConf.getTimeVar( + HiveConf.ConfVars.HIVE_LOCK_SLEEP_BETWEEN_RETRIES, TimeUnit.MILLISECONDS) + val MAX_SLEEP = Math.max(15000, defaultMaxSleep) + val maxNumWaits: Int = Math.max(0, hiveConf.getIntVar(HiveConf.ConfVars.HIVE_LOCK_NUMRETRIES)) + def backoff(): Unit = { + nextSleep *= 2 + if (nextSleep > MAX_SLEEP) nextSleep = MAX_SLEEP + try + Thread.sleep(nextSleep) + catch { + case _: InterruptedException => + + } + } + try { + var res: LockResponse = client.lock(lockReq) + // link lockId to queryId + var numRetries: Int = 0 + while (res.getState == LockState.WAITING && numRetries < maxNumWaits) { + numRetries += 1 + backoff() + res = client.checkLock(res.getLockid) + } + if (res.getState != LockState.ACQUIRED) { + throw HiveAcidErrors.couldNotAcquireLockException(state = res.getState.name()) + } + } catch { + case e: TException => + logWarning("Unable to acquire lock", e) + throw HiveAcidErrors.couldNotAcquireLockException(e) + } + } + + lock(createLockRequest()) + } + + private class HeartbeatRunnable() extends Runnable { + private def send(txn: HiveAcidTxn): Unit = { + try { + // Does not matter if txn is already ended + val resp = heartBeaterClient.heartbeatTxnRange(txn.txnId, txn.txnId) + if (resp.getAborted.nonEmpty || resp.getNosuch.nonEmpty) { + logError(s"Heartbeat failure for transaction id: ${txn.txnId} : ${resp.toString}." + + s"Aborting...") + } else { + logDebug(s"Heartbeat sent for txnId: ${txn.txnId}") + } + } catch { + // No action required because if heartbeat doesn't go for some time, transaction will be + // aborted by HMS automatically. 
We can abort the transaction here also if we are not + // able to send heartbeat for some time + case e: TException => + logWarning(s"Failure to heartbeat for txnId: ${txn.txnId}", e) + } + } + + override def run(): Unit = { + if (shutdownInitiated.get()) { + return + } + + // Close the txnManager + if (SqlUtils.hasSparkStopped(sparkSession)) { + close() + return + } + + if (HiveAcidTxnManager.activeTxns.nonEmpty) { + HiveAcidTxnManager.activeTxns.foreach { + case (_, txn) => send(txn) + } + } + } + } + + class HeartBeaterThread(val target: Runnable, val name: String) extends Thread(target, name) { + setDaemon(true) + } +} + +protected[hiveacid] object HiveAcidTxnManager { + // Maintain activeTxns inside txnManager instead of HiveAcidTxn + // object for it to be accessible to back ground thread running + // inside HiveAcidTxnManager. + protected val activeTxns = new scala.collection.mutable.HashMap[Long, HiveAcidTxn]() + def getTxn(txnId: Long): Option[HiveAcidTxn] = HiveAcidTxnManager.activeTxns.get(txnId) +} diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/util/.gitignore b/src/main/scala/com/qubole/spark/hiveacid/util/.gitignore similarity index 100% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/util/.gitignore rename to src/main/scala/com/qubole/spark/hiveacid/util/.gitignore diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/util/SerializableConfiguration.scala b/src/main/scala/com/qubole/spark/hiveacid/util/SerializableConfiguration.scala similarity index 85% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/util/SerializableConfiguration.scala rename to src/main/scala/com/qubole/spark/hiveacid/util/SerializableConfiguration.scala index 74393a0..ad8302d 100644 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/util/SerializableConfiguration.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/util/SerializableConfiguration.scala @@ -17,14 +17,17 @@ * limitations under the License. */ -package com.qubole.spark.datasources.hiveacid.util +package com.qubole.spark.hiveacid.util import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.hadoop.conf.Configuration -import org.apache.spark.util.Utils -class SerializableConfiguration(@transient var value: Configuration) extends Serializable { +/** + * Utility class to make configuration object serializable + */ +private[hiveacid] class SerializableConfiguration(@transient var value: Configuration) + extends Serializable { private def writeObject(out: ObjectOutputStream): Unit = Util.tryOrIOException { out.defaultWriteObject() value.write(out) diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/util/SerializableWritable.scala b/src/main/scala/com/qubole/spark/hiveacid/util/SerializableWritable.scala similarity index 86% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/util/SerializableWritable.scala rename to src/main/scala/com/qubole/spark/hiveacid/util/SerializableWritable.scala index 04485fd..5b80e10 100644 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/util/SerializableWritable.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/util/SerializableWritable.scala @@ -17,7 +17,7 @@ * limitations under the License. 
*/ -package com.qubole.spark.datasources.hiveacid.util +package com.qubole.spark.hiveacid.util import java.io._ @@ -25,11 +25,11 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.ObjectWritable import org.apache.hadoop.io.Writable -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.Utils - -@DeveloperApi -class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable { +/** + * Utility class to make a Writable serializable + */ +private[hiveacid] class SerializableWritable[T <: Writable](@transient var t: T) + extends Serializable { def value: T = t diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/util/Util.scala b/src/main/scala/com/qubole/spark/hiveacid/util/Util.scala similarity index 52% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/util/Util.scala rename to src/main/scala/com/qubole/spark/hiveacid/util/Util.scala index 3450d30..7d35442 100644 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/util/Util.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/util/Util.scala @@ -17,25 +17,15 @@ * limitations under the License. */ -package com.qubole.spark.datasources.hiveacid.util +package com.qubole.spark.hiveacid.util import java.io.IOException -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapred.{FileSplit, InputSplit} -import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit - import org.apache.spark.internal.Logging import scala.util.control.NonFatal - -object SplitFileSystemType extends Enumeration { - type SplitFileSystemType = Value - val NONE, BLOB_STORE= Value -} - -object Util extends Logging { +private[hiveacid] object Util extends Logging { def classForName(className: String, loadShaded: Boolean = false): Class[_] = { val classToLoad = if (loadShaded) { @@ -46,12 +36,17 @@ object Util extends Logging { Class.forName(classToLoad, true, Thread.currentThread().getContextClassLoader) } - val fileSystemSchemes: List[String] = List("s3", "s3n", "s3a", "wasb", "adl", - "oraclebmc", "oci") - + /** + * Detect whether this thread might be executing a shutdown hook. Will always return true if + * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. + * if System.exit was just called by a concurrent thread). + * + * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing + * an IllegalStateException. 
+ */ def inShutdown(): Boolean = { try { - val hook = new Thread { + val hook: Thread = new Thread { override def run() {} } // scalastyle:off runtimeaddshutdownhook @@ -59,44 +54,11 @@ object Util extends Logging { // scalastyle:on runtimeaddshutdownhook Runtime.getRuntime.removeShutdownHook(hook) } catch { - case ise: IllegalStateException => return true + case _: IllegalStateException => return true } false } - def getSplitScheme(path: String) : SplitFileSystemType.SplitFileSystemType = { - if(path == null) { - return SplitFileSystemType.NONE - } - if (fileSystemSchemes.contains(new Path(path).toUri.getScheme.toLowerCase)) { - return SplitFileSystemType.BLOB_STORE - } - return SplitFileSystemType.NONE - } - - def getSplitScheme[T >: InputSplit](split: T) : SplitFileSystemType.SplitFileSystemType = { - split match { - case f: FileSplit => - if (fileSystemSchemes.contains( - split.asInstanceOf[FileSplit].getPath.toUri.getScheme.toLowerCase)) { - SplitFileSystemType.BLOB_STORE - } else { - SplitFileSystemType.NONE - } - // When wholeTextFiles is used for reading multiple files in one go, - // the split has multiple paths in it. We get the scheme of the first - // path and use that for the rest too. - case cf : CombineFileSplit => - if (fileSystemSchemes.contains( - split.asInstanceOf[CombineFileSplit].getPath(0).toUri.getScheme.toLowerCase)) { - SplitFileSystemType.BLOB_STORE - } else { - SplitFileSystemType.NONE - } - case _ => SplitFileSystemType.NONE - } - } - def tryOrIOException[T](block: => T): T = { try { block @@ -109,6 +71,4 @@ object Util extends Logging { throw new IOException(e) } } - - } diff --git a/src/main/scala/com/qubole/spark/hiveacid/writer/TableWriter.scala b/src/main/scala/com/qubole/spark/hiveacid/writer/TableWriter.scala new file mode 100644 index 0000000..9fa60a7 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/writer/TableWriter.scala @@ -0,0 +1,185 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.qubole.spark.hiveacid.writer + +import scala.collection.JavaConverters._ +import scala.language.implicitConversions + +import com.qubole.spark.hiveacid._ +import com.qubole.spark.hiveacid.hive.HiveAcidMetadata +import com.qubole.spark.hiveacid.writer.hive.{HiveAcidFullAcidWriter, HiveAcidInsertOnlyWriter, HiveAcidWriterOptions} +import com.qubole.spark.hiveacid.transaction._ +import com.qubole.spark.hiveacid.util.SerializableConfiguration + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.command.AlterTableAddPartitionCommand +import org.apache.spark.sql.types.StructType + +/** + * Performs eager write of a dataframe df to a hive acid table based on operationType + * @param sparkSession - Spark session + * @param curTxn - Transaction object to acquire locks. + * @param hiveAcidMetadata - Hive acid table where we want to write dataframe + */ +private[hiveacid] class TableWriter(sparkSession: SparkSession, + curTxn: HiveAcidTxn, + hiveAcidMetadata: HiveAcidMetadata) extends Logging { + + private def getSchema(operationType: HiveAcidOperation.OperationType): StructType = { + val expectRowIdsInDataFrame = operationType match { + case HiveAcidOperation.INSERT_OVERWRITE | HiveAcidOperation.INSERT_INTO => false + case HiveAcidOperation.DELETE | HiveAcidOperation.UPDATE => true + case _ => throw HiveAcidErrors.invalidOperationType(operationType.toString) + } + + if (expectRowIdsInDataFrame) { + hiveAcidMetadata.tableSchemaWithRowId + } else { + hiveAcidMetadata.tableSchema + } + } + + private def getColumns(operationType: HiveAcidOperation.OperationType, + df: DataFrame): (Seq[Attribute], Array[Attribute], Seq[Attribute]) = { + + val columnNames = getSchema(operationType).fields.map(_.name) + + val allColumns = df.queryExecution.optimizedPlan.output.zip(columnNames).map { + case (attr, columnName) => + attr.withName(columnName) + } + + val allColumnNameToAttrMap = allColumns.map(attr => attr.name -> attr).toMap + + val partitionColumns = hiveAcidMetadata.partitionSchema.fields.map( + field => allColumnNameToAttrMap(field.name)) + + val dataColumns = allColumns.filterNot(partitionColumns.contains) + + (allColumns, partitionColumns, dataColumns) + } + + /** + * Common utility function to process all types of operations insert/update/delete + * for the hive acid table + * @param operationType type of operation. + * @param df data frame to be written into the table. + */ + def process(operationType: HiveAcidOperation.OperationType, + df: DataFrame): Unit = { + + val hadoopConf = sparkSession.sessionState.newHadoopConf() + + val (allColumns, partitionColumns, dataColumns) = getColumns(operationType, df) + + try { + + // FIXME: IF we knew the partition then we should + // only lock that partition. + curTxn.acquireLocks(hiveAcidMetadata, operationType, Seq()) + + // Create Snapshot !!! 
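+      // The snapshot pins the write id allocated to this transaction. That write id
+      // is captured in the WriterOptions built below, which are serialized and shipped
+      // to the executors; each task then streams its partition of rows through a
+      // HiveAcidFullAcidWriter or HiveAcidInsertOnlyWriter depending on the table type.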
+ val curSnapshot = HiveAcidTxn.createSnapshot(curTxn, hiveAcidMetadata) + + val writerOptions = new WriterOptions(curSnapshot.currentWriteId, + operationType, + new SerializableConfiguration(hadoopConf), + getSchema(operationType), + dataColumns, + partitionColumns, + allColumns, + sparkSession.sessionState.conf.sessionLocalTimeZone + ) + + val isFullAcidTable = hiveAcidMetadata.isFullAcidTable + + val hiveAcidWriterOptions = HiveAcidWriterOptions.get(hiveAcidMetadata, writerOptions) + + // This RDD is serialized and sent for distributed execution. + // All the access object in this needs to be serializable. + val processRddPartition = new (Iterator[InternalRow] => Seq[TablePartitionSpec]) with + Serializable { + override def apply(iterator: Iterator[InternalRow]): Seq[TablePartitionSpec] = { + val writer = if (isFullAcidTable) { + new HiveAcidFullAcidWriter(writerOptions, hiveAcidWriterOptions) + } else { + new HiveAcidInsertOnlyWriter(writerOptions, hiveAcidWriterOptions) + } + iterator.foreach { row => writer.process(row) } + writer.close() + writer.partitionsTouched() + } + } + + val resultRDD = + operationType match { + // Deleted rowId needs to be in the same bucketed file name + // as original row. To achieve that repartition it based on + // the rowId.bucketId. After this shuffle partitionId maps + // 1-to-1 with bucketId. + case HiveAcidOperation.DELETE | HiveAcidOperation.UPDATE => + df.sort("rowId.bucketId") + .toDF.queryExecution.executedPlan.execute() + case HiveAcidOperation.INSERT_OVERWRITE | HiveAcidOperation.INSERT_INTO => + df.queryExecution.executedPlan.execute() + case unknownOperation => + throw HiveAcidErrors.invalidOperationType(unknownOperation.toString) + } + + val touchedPartitions = sparkSession.sparkContext.runJob( + resultRDD, processRddPartition + ).flatten.toSet + + // Add new partition to table metadata under the transaction. + val existingPartitions = hiveAcidMetadata.getRawPartitions() + .map(_.getSpec) + .map(_.asScala.toMap) + + val newPartitions = touchedPartitions -- existingPartitions + + logDebug(s"existing partitions: ${touchedPartitions.size}, " + + s"partitions touched: ${touchedPartitions.size}, " + + s"new partitions to add to metastore: ${newPartitions.size}") + + if (newPartitions.nonEmpty) { + AlterTableAddPartitionCommand( + new TableIdentifier(hiveAcidMetadata.tableName, Option(hiveAcidMetadata.dbName)), + newPartitions.toSeq.map(p => (p, None)), + ifNotExists = true).run(sparkSession) + } + + // FIXME: Add the notification events for replication et al. + // + + logDebug("new partitions added successfully") + + } catch { + case e: Exception => + logError("Exception", e) + throw e + } + } +} + + diff --git a/src/main/scala/com/qubole/spark/hiveacid/writer/Writer.scala b/src/main/scala/com/qubole/spark/hiveacid/writer/Writer.scala new file mode 100644 index 0000000..aba329b --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/writer/Writer.scala @@ -0,0 +1,29 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.qubole.spark.hiveacid.writer + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec + +private[hiveacid] trait Writer { + def process(row: InternalRow): Unit + def close(): Unit + def partitionsTouched(): Seq[TablePartitionSpec] +} \ No newline at end of file diff --git a/src/main/scala/com/qubole/spark/hiveacid/writer/WriterOptions.scala b/src/main/scala/com/qubole/spark/hiveacid/writer/WriterOptions.scala new file mode 100644 index 0000000..c872f95 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/writer/WriterOptions.scala @@ -0,0 +1,37 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.qubole.spark.hiveacid.writer + +import com.qubole.spark.hiveacid.HiveAcidOperation +import com.qubole.spark.hiveacid.util.SerializableConfiguration +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.StructType + +/** + * Writer options which will be serialized and sent to each executor + */ +private[hiveacid] class WriterOptions(val currentWriteId: Long, + val operationType: HiveAcidOperation.OperationType, + val serializableHadoopConf: SerializableConfiguration, + val rowIDSchema: StructType, + val dataColumns: Seq[Attribute], + val partitionColumns: Seq[Attribute], + val allColumns: Seq[Attribute], + val timeZoneId: String) extends Serializable diff --git a/src/main/scala/com/qubole/spark/hiveacid/writer/hive/HiveAcidWriter.scala b/src/main/scala/com/qubole/spark/hiveacid/writer/hive/HiveAcidWriter.scala new file mode 100644 index 0000000..8896773 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/writer/hive/HiveAcidWriter.scala @@ -0,0 +1,455 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.qubole.spark.hiveacid.writer.hive + +import java.util.Properties + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import com.qubole.shaded.hadoop.hive.ql.exec.FileSinkOperator.RecordWriter +import com.qubole.shaded.hadoop.hive.ql.exec.Utilities +import com.qubole.shaded.hadoop.hive.ql.io.{BucketCodec, HiveFileFormatUtils, RecordIdentifier, RecordUpdater, _} +import com.qubole.shaded.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import com.qubole.shaded.hadoop.hive.serde2.{Deserializer, SerDeUtils} +import com.qubole.shaded.hadoop.hive.serde2.Serializer +import com.qubole.shaded.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, ObjectInspectorUtils, StructObjectInspector} +import com.qubole.shaded.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption +import com.qubole.spark.hiveacid.{HiveAcidErrors, HiveAcidOperation} +import com.qubole.spark.hiveacid.util.Util +import com.qubole.spark.hiveacid.writer.{Writer, WriterOptions} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred.{JobConf, Reporter} +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils +import org.apache.spark.sql.catalyst.expressions.{Cast, Concat, Expression, Literal, ScalaUDF, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.hive.Hive3Inspectors +import org.apache.spark.sql.types.StringType + +abstract private[writer] class HiveAcidWriter(val options: WriterOptions, + val HiveAcidOptions: HiveAcidWriterOptions) + extends Writer with Logging { + + // Takes as input partition columns returns expression to + // create partition path. 
See udf `getPartitionPathString` + // input = p1, p2, p3 + // output = Seq(Expr(p1=*), Literal(/), Expr(p2=?), Literal('/'), Expr(p3=?)) + private val partitionPathExpression: Expression = Concat( + options.partitionColumns.zipWithIndex.flatMap { case (c, i) => + val partitionName = ScalaUDF( + ExternalCatalogUtils.getPartitionPathString _, + StringType, + Seq(Literal(c.name), Cast(c, StringType, Option(options.timeZoneId))), + Seq(true, true)) + if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) + }) + + private val getPartitionPath: InternalRow => String = { + val proj = UnsafeProjection.create(Seq(partitionPathExpression), + options.partitionColumns) + row => proj(row).getString(0) + } + + private val partitionsTouchedSet = scala.collection.mutable.HashSet[TablePartitionSpec]() + def partitionsTouched(): Seq[TablePartitionSpec] = partitionsTouchedSet.toSeq + + // Utility functions for extracting partition and data parts of row + // + protected val getDataValues: InternalRow => UnsafeRow = { + val proj = UnsafeProjection.create(options.dataColumns, options.allColumns) + row => proj(row) + } + + protected val getPartitionValues: InternalRow => UnsafeRow = { + val proj = UnsafeProjection.create(options.partitionColumns, + options.allColumns) + row => proj(row) + } + + protected val jobConf: JobConf = { + val hConf = options.serializableHadoopConf.value + new JobConf(hConf) + } + + protected val sparkHiveRowConverter = new SparkHiveRowConverter(options, HiveAcidOptions, jobConf) + + // Cache of writers + protected val writers: mutable.Map[(String, Int, Int), Any] = scala.collection.mutable.Map[(String, Int, Int), Any]() + + lazy protected val taskId: Int = + Utilities.getTaskIdFromFilename(TaskContext.get.taskAttemptId().toString).toInt + + protected def getOrCreateWriter(partitionRow: InternalRow, acidBucketId: Int): Any = { + + val partitionBasePath = if (options.partitionColumns.isEmpty) { + new Path(HiveAcidOptions.rootPath) + } else { + val path = getPartitionPath(partitionRow) + partitionsTouchedSet.add(PartitioningUtils.parsePathFragment(path)) + new Path(HiveAcidOptions.rootPath, path) + } + + writers.getOrElseUpdate((partitionBasePath.toUri.toString, taskId, acidBucketId), + createWriter(partitionBasePath, acidBucketId)) + } + + protected def createWriter(path: Path, acidBucketId: Int): Any = {} + + // Cache structure to store row + protected val hiveRow = new Array[Any](sparkHiveRowConverter.numFields) + + lazy protected val fileSinkConf: FileSinkDesc = { + val fileSinkConf = HiveAcidOptions.getFileSinkDesc + fileSinkConf.setDirName(new Path(HiveAcidOptions.rootPath)) + fileSinkConf + } +} + +/** + * This class is responsible for writing a InternalRow into a Full-ACID table + * This can handle InsertInto/InsertOverwrite/Update/Delete operations + * It has method `process` which takes 1 InternalRow and processes it based on + * OperationType (insert/update/delete etc). + * + * It is assumed that the InternalRow passed to process is in the right form + * i.e. + * for InsertInto/InsertOverwrite operations, row is expected to contain data and no row ids + * for Delete, row is expected to contain rowId + * for Update, row is expected to contain data as well as row Id + * + * BucketID Conundrum: In acid table single bucket cannot have multiple files. + * + * When performing inserts for the first time, bucketId needs to assigned such that + * only 1 tasks writes the data to 1 bucket. Scheme used currently is to use shuffle + * partitionID as bucketID. 
This scheme works even for task retry as retried task + * overwrite already half written file, but has a problem of skew as the shuffle + * partitionId is not based on data but on number of splits, it has bias towards + * adding more data on lower number bucket. Ideally the data should been equally + * distributed across multiple inserts. + * + * When performing delete/updates (updates are inserts + deletes) inside delete + * delta directory bucket number in file name needs to match that of base and + * delta file for which the delete file is being written. This is to be able + * to prune files by name itself (penalty paid at the time of write). For handling + * it data is repartitioned on rowId.bucketId before writing delete delta file. + * (see [[com.qubole.spark.hiveacid.writer.TableWriter]]) + * + * For Bucketed tables + * + * Invariant of however many splits that may have been created at the source the + * data into single bucket needs to be written by single task. Hence before writing + * the data is Same bucket cannot have multiple files + * + * @param options - writer options to use + * @param HiveAcidOptions - Hive3 related writer options. + */ +private[writer] class HiveAcidFullAcidWriter(options: WriterOptions, + HiveAcidOptions: HiveAcidWriterOptions) + extends HiveAcidWriter(options, HiveAcidOptions) with Logging { + + private lazy val rowIdColNum = options.operationType match { + case HiveAcidOperation.INSERT_INTO | HiveAcidOperation.INSERT_OVERWRITE => + -1 + case HiveAcidOperation.UPDATE | HiveAcidOperation.DELETE => + 0 + case x => + throw new RuntimeException(s"Invalid write operation $x") + } + + override protected def createWriter(path: Path, acidBucketId: Int): Any = { + + val tableDesc = HiveAcidOptions.getFileSinkDesc.getTableInfo + + val recordUpdater = HiveFileFormatUtils.getAcidRecordUpdater( + jobConf, + tableDesc, + acidBucketId, + HiveAcidOptions.getFileSinkDesc, + path, + sparkHiveRowConverter.getObjectInspector, + Reporter.NULL, + rowIdColNum) + + val acidOutputFormatOptions = new AcidOutputFormat.Options(jobConf) + .writingBase(options.operationType == HiveAcidOperation.INSERT_OVERWRITE) + .bucket(acidBucketId) + .minimumWriteId(fileSinkConf.getTableWriteId) + .maximumWriteId(fileSinkConf.getTableWriteId) + .statementId(fileSinkConf.getStatementId) + + val (createDelta, createDeleteDelta) = options.operationType match { + case HiveAcidOperation.INSERT_INTO | HiveAcidOperation.INSERT_OVERWRITE => (true, false) + case HiveAcidOperation.UPDATE => (true, true) + case HiveAcidOperation.DELETE => (false, true) + case unknownOperation => throw HiveAcidErrors.invalidOperationType(unknownOperation.toString) + } + + val fs = path.getFileSystem(jobConf) + + def createVersionFile(acidOptions: AcidOutputFormat.Options): Unit = { + try { + AcidUtils.OrcAcidVersion.writeVersionFile( + AcidUtils.createFilename(path, acidOptions).getParent, fs) + } catch { + case _: Exception => + logError("Version file already found") + case scala.util.control.NonFatal(_) => + logError("Version file already found - non fatal") + case _: Throwable => + logError("Version file already found - shouldn't be caught") + } + } + + if (createDelta) { + createVersionFile(acidOutputFormatOptions) + } + + if (createDeleteDelta) { + createVersionFile(acidOutputFormatOptions.writingDeleteDelta(true)) + } + recordUpdater + } + + private def getBucketID(dataRow: InternalRow): Int = { + // FIXME: Deal with bucketed table. 
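+    // For non-bucketed tables the bucket id is synthesized as follows:
+    //   - INSERT INTO / INSERT OVERWRITE: use the task's shuffle partition id, so a
+    //     single task writes each acid bucket file.
+    //   - DELETE / UPDATE: reuse the writer id encoded in the incoming rowId's bucket
+    //     property (decoded with BucketCodec.V1), so the delete delta file gets the
+    //     same bucket number as the base/delta file holding the original row.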
+ val bucketedTable = false + if (bucketedTable) { + // getBucketIdFromCol(partitionRow) + 0 + } else { + options.operationType match { + case HiveAcidOperation.INSERT_INTO | HiveAcidOperation.INSERT_OVERWRITE => + Utilities.getTaskIdFromFilename(TaskContext.getPartitionId().toString) + .toInt + case HiveAcidOperation.DELETE | HiveAcidOperation.UPDATE => + val rowID = dataRow.get(rowIdColNum, options.rowIDSchema) + // FIXME: Currently hard coding codec as V1 and also bucket ordinal as 1. + BucketCodec.V1.decodeWriterId(rowID.asInstanceOf[UnsafeRow].getInt(1)) + case x => + throw new RuntimeException(s"Invalid write operation $x") + } + } + } + + /** + * Process an Spark InternalRow + * + * @param row row to be processed + */ + def process(row: InternalRow): Unit = { + // Identify the partitionColumns and nonPartitionColumns in row + val partitionColRow = getPartitionValues(row) + val dataColRow = getDataValues(row) + + // Get the recordWriter for this partitionedRow + val recordUpdater = + getOrCreateWriter(partitionColRow, getBucketID(dataColRow)).asInstanceOf[RecordUpdater] + + val recordValue = sparkHiveRowConverter.toHiveRow(dataColRow, hiveRow) + + options.operationType match { + case HiveAcidOperation.INSERT_INTO | HiveAcidOperation.INSERT_OVERWRITE => + recordUpdater.insert(options.currentWriteId, recordValue) + case HiveAcidOperation.UPDATE => + recordUpdater.update(options.currentWriteId, recordValue) + case HiveAcidOperation.DELETE => + recordUpdater.delete(options.currentWriteId, recordValue) + case x => + throw new RuntimeException(s"Invalid write operation $x") + } + } + + def close(): Unit = { + writers.foreach( x => try { + // TODO: Seems the boolean value passed into close does not matter. + x._2.asInstanceOf[RecordUpdater].close(false) + } + catch { + case e: Exception => + logError("Unable to close " + x._2 + " due to: " + e.getMessage) + }) + } +} + +/** + * This class is responsible for writing a InternalRow into a insert-only table + * This can handle InsertInto/InsertOverwrite. This does not support Update/Delete operations + * It has method `process` which takes 1 InternalRow and processes it based on + * OperationType. row is expected to contain data and no row ids + * + * @param options writer options to use + * @param HiveAcidOptions hive3 specific options, which is passed into underlying hive3 API + */ +private[writer] class HiveAcidInsertOnlyWriter(options: WriterOptions, + HiveAcidOptions: HiveAcidWriterOptions) + extends HiveAcidWriter(options, HiveAcidOptions) { + + override protected def createWriter(path: Path, acidBucketId: Int): Any = { + val outputClass = sparkHiveRowConverter.serializer.getSerializedClass + + val acidOutputFormatOptions = new AcidOutputFormat.Options(jobConf) + .writingBase(options.operationType == HiveAcidOperation.INSERT_OVERWRITE) + .bucket(acidBucketId) + .minimumWriteId(fileSinkConf.getTableWriteId) + .maximumWriteId(fileSinkConf.getTableWriteId) + .statementId(fileSinkConf.getStatementId) + + // FIXME: Hack to remove bucket prefix for Insert only table. 
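+    // Insert-only tables are written through the table's regular record writer rather
+    // than an ORC RecordUpdater, so the "bucket_" prefix produced by
+    // AcidUtils.createFilename is stripped and the task id is appended instead,
+    // keeping file names distinct across the tasks writing into the same delta directory.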
+ var fullPathStr = AcidUtils.createFilename(path, + acidOutputFormatOptions).toString.replace("bucket_", "") + + fullPathStr += "_" + taskId + + logDebug(s" $fullPathStr") + + HiveFileFormatUtils.getHiveRecordWriter( + jobConf, + sparkHiveRowConverter.tableDesc, + outputClass, + HiveAcidOptions.getFileSinkDesc, + new Path(fullPathStr), + Reporter.NULL) + } + /** + * Process an Spark InternalRow + * @param row row to be processed + */ + override def process(row: InternalRow): Unit = { + // Identify the partitionColumns and nonPartitionColumns in row + val partitionColRow = getPartitionValues(row) + val dataColRow = getDataValues(row) + + // FIXME: Find the bucket id based on some sort hash on the data row + val bucketId = 0 + + // Get the recordWriter for this partitionedRow + val writer = getOrCreateWriter(partitionColRow, bucketId) + + val recordValue = + sparkHiveRowConverter.serialize(sparkHiveRowConverter.toHiveRow(dataColRow, hiveRow)) + + options.operationType match { + case HiveAcidOperation.INSERT_INTO | HiveAcidOperation.INSERT_OVERWRITE => + writer.asInstanceOf[RecordWriter].write(recordValue) + case x => + throw new RuntimeException(s"Invalid write operation $x") + } + } + + def close(): Unit = { + writers.foreach( x => try { + // TODO: Seems the boolean value passed into close does not matter. + x._2.asInstanceOf[RecordWriter].close(false) + } + catch { + case e: Exception => + logError("Unable to close " + x._2 + " due to: " + e.getMessage) + }) + } + +} + +/** + * Utility class to convert a spark InternalRow to a row required by Hive + * @param options - hive acid writer options + * @param jobConf - job conf + */ +private[hive] class SparkHiveRowConverter(options: WriterOptions, + HiveAcidOptions: HiveAcidWriterOptions, + jobConf: JobConf) extends Hive3Inspectors { + + val tableDesc: TableDesc = HiveAcidOptions.getFileSinkDesc.getTableInfo + + // NB: Can't use tableDesc.getDeserializer as it uses Reflection + // internally which doesn't work because of shading. 
So copied its logic + lazy val serializer: Serializer = { + val serializer = Util.classForName(tableDesc.getSerdeClassName, + loadShaded = true).asInstanceOf[Class[Serializer]].newInstance() + serializer.initialize(jobConf, tableDesc.getProperties) + serializer + } + + lazy val deserializer: Deserializer = { + val deserializer = Util.classForName(tableDesc.getSerdeClassName, + loadShaded = true).asInstanceOf[Class[Deserializer]].newInstance() + SerDeUtils.initializeSerDe(deserializer, jobConf, tableDesc.getProperties, + null.asInstanceOf[Properties]) + deserializer + } + + // Object Inspector Objects + // + private val recIdInspector = RecordIdentifier.StructInfo.oi + + private val oIWithoutRowId = ObjectInspectorUtils + .getStandardObjectInspector( + deserializer.getObjectInspector, + ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] + + + private val oIWithRowId = { + val dataStructFields = asScalaIteratorConverter( + oIWithoutRowId.getAllStructFieldRefs.iterator).asScala.toSeq + + val newFieldNameSeq = Seq("rowId") ++ dataStructFields.map(_.getFieldName) + val newOISeq = Seq(recIdInspector) ++ + dataStructFields.map(_.getFieldObjectInspector) + ObjectInspectorFactory.getStandardStructObjectInspector( + newFieldNameSeq.asJava, newOISeq.asJava + ) + } + + val objectInspector: StructObjectInspector = options.operationType match { + case HiveAcidOperation.INSERT_INTO | HiveAcidOperation.INSERT_OVERWRITE => + oIWithoutRowId + case HiveAcidOperation.DELETE | HiveAcidOperation.UPDATE => + oIWithRowId + case x => + throw new RuntimeException(s"Invalid write operation $x") + } + + private val fieldOIs = + objectInspector.getAllStructFieldRefs.asScala.map(_.getFieldObjectInspector).toArray + + def getObjectInspector: StructObjectInspector = objectInspector + + def numFields: Int = fieldOIs.length + + def serialize(hiveRow: Array[Any]): Writable = { + serializer.serialize(hiveRow, objectInspector.asInstanceOf[ObjectInspector]) + } + + def toHiveRow(sparkRow: InternalRow, hiveRow: Array[Any]): Array[Any] = { + val dataTypes = options.dataColumns.map(_.dataType).toArray + val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt) } + + var i = 0 + while (i < fieldOIs.length) { + hiveRow(i) = if (sparkRow.isNullAt(i)) null else wrappers(i)(sparkRow.get(i, dataTypes(i))) + i += 1 + } + hiveRow + } +} diff --git a/src/main/scala/com/qubole/spark/hiveacid/writer/hive/HiveAcidWriterOptions.scala b/src/main/scala/com/qubole/spark/hiveacid/writer/hive/HiveAcidWriterOptions.scala new file mode 100644 index 0000000..ed475dd --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/writer/hive/HiveAcidWriterOptions.scala @@ -0,0 +1,52 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.qubole.spark.hiveacid.writer.hive + +import com.qubole.shaded.hadoop.hive.ql.plan.FileSinkDesc +import com.qubole.spark.hiveacid.HiveAcidOperation +import com.qubole.spark.hiveacid.hive.HiveAcidMetadata +import com.qubole.spark.hiveacid.writer.WriterOptions +import org.apache.hadoop.fs.Path + +private[writer] class HiveAcidWriterOptions(val rootPath: String, + fileSinkDesc: FileSinkDesc) extends Serializable { + lazy val getFileSinkDesc: FileSinkDesc = { + fileSinkDesc.setDirName(new Path(rootPath)) + fileSinkDesc + } +} + +private[writer] object HiveAcidWriterOptions { + def get(hiveAcidMetadata: HiveAcidMetadata, + options: WriterOptions): HiveAcidWriterOptions = { + lazy val fileSinkDescriptor: FileSinkDesc = { + val fileSinkDesc: FileSinkDesc = new FileSinkDesc() + fileSinkDesc.setTableInfo(hiveAcidMetadata.tableDesc) + fileSinkDesc.setTableWriteId(options.currentWriteId) + if (options.operationType == HiveAcidOperation.INSERT_OVERWRITE) { + fileSinkDesc.setInsertOverwrite(true) + } + fileSinkDesc + } + new HiveAcidWriterOptions(rootPath = hiveAcidMetadata.rootPath.toUri.toString, + fileSinkDesc = fileSinkDescriptor) + } + +} diff --git a/src/main/scala/org/apache/spark/sql/SqlUtils.scala b/src/main/scala/org/apache/spark/sql/SqlUtils.scala new file mode 100644 index 0000000..63c14b8 --- /dev/null +++ b/src/main/scala/org/apache/spark/sql/SqlUtils.scala @@ -0,0 +1,51 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} + +object SqlUtils { + def convertToDF(sparkSession: SparkSession, plan : LogicalPlan): DataFrame = { + Dataset.ofRows(sparkSession, plan) + } + + def resolveReferences(sparkSession: SparkSession, + expr: Expression, + planContaining: LogicalPlan): Expression = { + val newPlan = FakeLogicalPlan(expr, Seq(planContaining)) + sparkSession.sessionState.analyzer.execute(newPlan) match { + case FakeLogicalPlan(resolvedExpr: Expression, _) => + // Return even if it did not successfully resolve + resolvedExpr + case _ => + expr + // This is unexpected + } + } + def hasSparkStopped(sparkSession: SparkSession): Boolean = { + sparkSession.sparkContext.stopped.get() + } +} + +case class FakeLogicalPlan(expr: Expression, children: Seq[LogicalPlan]) + extends LogicalPlan { + override def output: Seq[Attribute] = children.foldLeft(Seq[Attribute]())((out, child) => out ++ child.output) +} diff --git a/src/main/scala/com/qubole/spark/datasources/hiveacid/rdd/Hive3Inspectors.scala b/src/main/scala/org/apache/spark/sql/hive/Hive3Inspectors.scala similarity index 98% rename from src/main/scala/com/qubole/spark/datasources/hiveacid/rdd/Hive3Inspectors.scala rename to src/main/scala/org/apache/spark/sql/hive/Hive3Inspectors.scala index cbed37b..968c393 100644 --- a/src/main/scala/com/qubole/spark/datasources/hiveacid/rdd/Hive3Inspectors.scala +++ b/src/main/scala/org/apache/spark/sql/hive/Hive3Inspectors.scala @@ -17,27 +17,32 @@ * limitations under the License. */ -package com.qubole.spark.datasources.hiveacid.rdd +package org.apache.spark.sql.hive import java.lang.reflect.{ParameterizedType, Type, WildcardType} -import java.sql.Date -import java.sql.Timestamp import scala.collection.JavaConverters._ -import org.apache.hadoop.{io => hadoopIo} + import com.qubole.shaded.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} import com.qubole.shaded.hadoop.hive.serde2.{io => hiveIo} import com.qubole.shaded.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} import com.qubole.shaded.hadoop.hive.serde2.objectinspector.primitive._ import com.qubole.shaded.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} -import com.qubole.spark.datasources.hiveacid.AnalysisException +import com.qubole.spark.hiveacid.AnalysisException +import org.apache.hadoop.{io => hadoopIo} + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types + import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +/** + * This class is similar to org.apache.spark.sql.hive.HiveInspectors. 
+ * Changes are made here to make it work with Hive3 + */ trait Hive3Inspectors { def javaTypeToDataType(clz: Type): DataType = clz match { @@ -136,7 +141,7 @@ trait Hive3Inspectors { case _: StringObjectInspector if x.preferWritable() => withNullSafe(o => getStringWritable(o)) case _: StringObjectInspector => - withNullSafe(o => o.asInstanceOf[UTF8String].toString()) + withNullSafe(o => o.asInstanceOf[UTF8String].toString) case _: IntObjectInspector if x.preferWritable() => withNullSafe(o => getIntWritable(o)) case _: IntObjectInspector => @@ -366,7 +371,7 @@ trait Hive3Inspectors { .toArray val constant = new GenericArrayData(values) _ => constant - case poi: VoidObjectInspector => + case _: VoidObjectInspector => _ => null // always be null for void object inspector case pi: PrimitiveObjectInspector => pi match { // We think HiveVarchar/HiveChar is also a String @@ -453,9 +458,7 @@ trait Hive3Inspectors { data: Any => { if (data != null) { toCatalystDecimal(x, data) - } else { - null - } + } else null } case x: BinaryObjectInspector if x.preferWritable() => data: Any => { @@ -464,12 +467,10 @@ trait Hive3Inspectors { // In order to keep backward-compatible, we have to copy the // bytes with old apis val bw = x.getPrimitiveWritableObject(data) - val result = new Array[Byte](bw.getLength()) - System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength()) + val result = new Array[Byte](bw.getLength) + System.arraycopy(bw.getBytes, 0, result, 0, bw.getLength) result - } else { - null - } + } else null } case x: DateObjectInspector if x.preferWritable() => data: Any => { @@ -483,7 +484,7 @@ trait Hive3Inspectors { case x: DateObjectInspector => data: Any => { if (data != null) { - val y = x.getPrimitiveJavaObject(data).toEpochMilli + val y: Long = x.getPrimitiveJavaObject(data).toEpochMilli DateTimeUtils.fromJavaDate(new java.sql.Date(y)) } else { null @@ -596,7 +597,7 @@ trait Hive3Inspectors { def wrap( row: InternalRow, - wrappers: Array[(Any) => Any], + wrappers: Array[Any => Any], cache: Array[AnyRef], dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 @@ -610,7 +611,7 @@ trait Hive3Inspectors { def wrap( row: Seq[Any], - wrappers: Array[(Any) => Any], + wrappers: Array[Any => Any], cache: Array[AnyRef], dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 diff --git a/src/test/java/com/qubole/spark/datasources/hiveacid/TestHiveClient.java b/src/test/java/com/qubole/spark/hiveacid/TestHiveClient.java similarity index 84% rename from src/test/java/com/qubole/spark/datasources/hiveacid/TestHiveClient.java rename to src/test/java/com/qubole/spark/hiveacid/TestHiveClient.java index e483402..a7dc9f9 100644 --- a/src/test/java/com/qubole/spark/datasources/hiveacid/TestHiveClient.java +++ b/src/test/java/com/qubole/spark/hiveacid/TestHiveClient.java @@ -17,7 +17,7 @@ * limitations under the License. 
*/ -package com.qubole.spark.datasources.hiveacid; +package com.qubole.spark.hiveacid; import java.sql.Connection; @@ -29,20 +29,14 @@ import java.io.StringWriter; -import java.util.logging.Level; -import java.util.logging.Logger; -import java.util.logging.*; - public class TestHiveClient { - /* - * Before running this docker container with HS2 / HMS / Hadoop running - */ - private static String driverName = "com.qubole.shaded.hive.jdbc.HiveDriver"; private static Connection con = null; private static Statement stmt = null; TestHiveClient() { try { + // Before running this docker container with HS2 / HMS / Hadoop running + String driverName = "com.qubole.shaded.hive.jdbc.HiveDriver"; Class.forName(driverName); } catch (ClassNotFoundException e) { e.printStackTrace(); @@ -69,7 +63,7 @@ public String executeQuery(String cmd) throws Exception { rs = null; } catch (Exception e) { - System.out.println("Failed execute statement "+ e); + System.out.println("Failed execute query statement \""+ cmd +"\" Error:"+ e); if (rs != null ) { rs.close(); } @@ -78,10 +72,14 @@ public String executeQuery(String cmd) throws Exception { } public void execute(String cmd) throws SQLException { - stmt.execute(cmd); + try { + stmt.execute(cmd); + } catch (Exception e) { + System.out.println("Failed execute statement \""+ cmd +"\" Error:"+ e); + } } - public String resultStr(ResultSet rs) throws SQLException { + private String resultStr(ResultSet rs) throws SQLException { StringWriter outputWriter = new StringWriter(); ResultSetMetaData rsmd = rs.getMetaData(); int columnsNumber = rsmd.getColumnCount(); @@ -104,11 +102,11 @@ public String resultStr(ResultSet rs) throws SQLException { public void teardown() throws SQLException { if (stmt != null) { stmt.close(); + stmt = null; } if (con != null) { con.close(); + con = null; } - stmt = null; - con = null; } } diff --git a/src/test/scala/com/qubole/spark/datasources/hiveacid/Table.scala b/src/test/scala/com/qubole/spark/datasources/hiveacid/Table.scala deleted file mode 100644 index 55b3ac6..0000000 --- a/src/test/scala/com/qubole/spark/datasources/hiveacid/Table.scala +++ /dev/null @@ -1,233 +0,0 @@ -/* - * Copyright 2019 Qubole, Inc. All rights reserved. - * - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.qubole.spark.datasources.hiveacid - -import org.joda.time.DateTime -import scala.collection.mutable.{ListBuffer, HashMap} -/* - * - * - */ -class Table ( - private val dbName: String, - private val tName: String, - private val extraColMap: Map[String, String], - private val tblProp: String, - private val isPartitioned: Boolean = false) { - - private var colMap = Map("key" -> "int") ++ extraColMap - private var colMapWithPartitionedCols = { - if (isPartitioned) { - Map("load_date" -> "int") ++ colMap - } else { - colMap - } - } - - // NB Add date column as well apparently always in the end - private def getRow(key: Int): String = colMap.map( x => { - x._2 match { - case "date" => s"'${(new DateTime(((key * 1000L) + 151502791900L))).toString}'" - case _ => key.toString - }}).mkString(", ") + {if (isPartitioned) s", '${(new DateTime(((key * 1000L) + 151502791900L))).toString}'" else ""} - - private def getColDefString = colMap.map(x => x._1 + " " + x._2).mkString(",") - - // FIXME: Add load_date column of partitioned table in order by clause - private def sparkOrderBy: String = sparkOrderBy(sparkTname) - private def hiveOrderBy: String = hiveOrderBy(tName) - private def sparkOrderBy(aliasedTable: String): String = - colMapWithPartitionedCols.map(x => s"${aliasedTable}.${x._1}").mkString(", ") - private def hiveOrderBy(aliasedTable: String): String = - colMapWithPartitionedCols.map(x => s"$aliasedTable.${x._1}").mkString(", ") - private def getCols = colMapWithPartitionedCols.map(x => x._1).mkString(", ") - - def getColMap = colMapWithPartitionedCols - - def hiveTname = s"$dbName.$tName" - def hiveTname1 = s"$tName" - def sparkTname = s"${dbName}.spark_${tName}" - - def hiveCreate = s"CREATE TABLE ${hiveTname} (${getColDefString}) ${tblProp}" - def hiveSelect = s"SELECT * FROM ${hiveTname} t1 ORDER BY ${hiveOrderBy("t1")}" - def hiveSelectWithPred(pred: String) = - s"SELECT * FROM ${hiveTname} t1 where ${pred} ORDER BY ${hiveOrderBy("t1")}" - def hiveSelectWithProj = s"SELECT intCol FROM ${hiveTname} ORDER BY intCol" - def hiveDrop = s"DROP TABLE IF EXISTS ${hiveTname}" - - def sparkCreate = s"CREATE TABLE ${sparkTname} USING HiveAcid OPTIONS('table' '${hiveTname}')" - def sparkSelect = s"SELECT * FROM ${sparkTname} t1 ORDER BY ${sparkOrderBy("t1")}" - def sparkSelectWithPred(pred: String) = - s"SELECT * FROM ${sparkTname} t1 where ${pred} ORDER BY ${sparkOrderBy("t1")}" - def sparkSelectWithProj = s"SELECT intCol FROM ${sparkTname} ORDER BY intCol" - def sparkDFProj = "intCol" - def sparkDrop = s"DROP TABLE IF EXISTS ${sparkTname}" - - - - def insertIntoHiveTableKeyRange(startKey: Int, endKey: Int): String = - s"INSERT INTO TABLE ${hiveTname} (${getCols}) " + (startKey to endKey).map { key => s" select ${getRow(key)} " }.mkString(" UNION ALL ") - def insertIntoHiveTableKey(key: Int): String = - s"INSERT INTO ${hiveTname} (${getCols}) VALUES (${getRow(key)})" - def deleteFromHiveTableKey(key: Int): String = - s"DELETE FROM ${hiveTname} where key = ${key}" - def updateInHiveTableKey(key: Int): String = - s"UPDATE ${hiveTname} set intCol = intCol * 10 where key = ${key}" - - def updateByMergeHiveTable = - s" merge into ${hiveTname} t using (select distinct ${getCols} from ${hiveTname}) s on s.key=t.key " + - s" when matched and s.key%2=0 then update set intCol=s.intCol * 10 " + - s" when matched and s.key%2=1 then delete " + - s" when not matched then insert values(${getRow(1000)})" - - def disableCompaction = s"ALTER TABLE ${hiveTname} SET TBLPROPERTIES ('NO_AUTO_COMPACTION' 
= 'true')" - def minorCompaction = s"ALTER TABLE ${hiveTname} COMPACT 'minor'" - def majorCompaction = s"ALTER TABLE ${hiveTname} COMPACT 'major'" - - def alterToTransactionalInsertOnlyTable = - s"ALTER TABLE ${hiveTname} SET TBLPROPERTIES ('transactional'='true', 'transactional_properties'='insert_only')" - def alterToTransactionalFullAcidTable = - s"ALTER TABLE ${hiveTname} SET TBLPROPERTIES ('transactional'='true', 'transactional_properties'='default')" -} - -object Table { - - // Create table string builder - // 1st param - private val partitionedStr = "PARTITIONED BY (load_date int) " - // 2nd param - private val clusteredStr = "CLUSTERED BY(key) INTO 3 BUCKETS " - // 3rd param - private val orcStr = "STORED AS ORC " - private val parquetStr = "STORED AS PARQUET " - private val textStr = "STORED AS TEXTFILE " - private val avroStr = "STORED AS AVRO " - // 4th param - private val fullAcidStr = " TBLPROPERTIES ('transactional'='true', 'transactional_properties'='default') " - private val insertOnlyStr = " TBLPROPERTIES ('transactional'='true', 'transactional_properties'='insert_only') " - - val orcFullACIDTable = orcStr + fullAcidStr - val orcPartitionedFullACIDTable = partitionedStr + orcStr + fullAcidStr - val orcBucketedPartitionedFullACIDTable = clusteredStr + partitionedStr + orcStr + fullAcidStr - val orcBucketedFullACIDTable = clusteredStr + orcStr + fullAcidStr - - val parquetFullACIDTable = parquetStr + fullAcidStr - val parquetPartitionedFullACIDTable = partitionedStr + parquetStr + fullAcidStr - val parquetBucketedPartitionedFullACIDTable = clusteredStr + partitionedStr + parquetStr + fullAcidStr - val parquetBucketedFullACIDTable = clusteredStr + parquetStr + fullAcidStr - - val textFullACIDTable = textStr + fullAcidStr - val textPartitionedFullACIDTable = partitionedStr + textStr + fullAcidStr - val textBucketedPartitionedFullACIDTable = clusteredStr + partitionedStr + textStr + fullAcidStr - val textBucketedFullACIDTable = clusteredStr + textStr + fullAcidStr - - val avroFullACIDTable = avroStr + fullAcidStr - val avroPartitionedFullACIDTable = partitionedStr + avroStr + fullAcidStr - val avroBucketedPartitionedFullACIDTable = clusteredStr + partitionedStr + avroStr + fullAcidStr - val avroBucketedFullACIDTable = clusteredStr + avroStr + fullAcidStr - - val orcInsertOnlyTable = orcStr + insertOnlyStr - val orcPartitionedInsertOnlyTable = partitionedStr + orcStr + insertOnlyStr - val orcBucketedPartitionedInsertOnlyTable = clusteredStr + partitionedStr + orcStr + insertOnlyStr - val orcBucketedInsertOnlyTable = clusteredStr + orcStr + insertOnlyStr - - val parquetInsertOnlyTable = parquetStr + insertOnlyStr - val parquetPartitionedInsertOnlyTable = partitionedStr + parquetStr + insertOnlyStr - val parquetBucketedPartitionedInsertOnlyTable = clusteredStr + partitionedStr + parquetStr + insertOnlyStr - val parquetBucketedInsertOnlyTable = clusteredStr + parquetStr + insertOnlyStr - - val textInsertOnlyTable = textStr + insertOnlyStr - val textPartitionedInsertOnlyTable = partitionedStr + textStr + insertOnlyStr - val textBucketedPartitionedInsertOnlyTable = clusteredStr + partitionedStr + textStr + insertOnlyStr - val textBucketedInsertOnlyTable = clusteredStr + textStr + insertOnlyStr - - val avroInsertOnlyTable = avroStr + insertOnlyStr - val avroPartitionedInsertOnlyTable = partitionedStr + avroStr + insertOnlyStr - val avroBucketedPartitionedInsertOnlyTable = clusteredStr + partitionedStr + avroStr + insertOnlyStr - val avroBucketedInsertOnlyTable = clusteredStr + avroStr 
+ insertOnlyStr - - private def generateTableVariations(fileFormatTypes: Array[String], - partitionedTypes: Array[String], - clusteredTypes: Array[String], - acidTypes: Array[String]): List[(String, Boolean)] = { - var tblTypes = new ListBuffer[(String, Boolean)]() - for (fileFormat <- fileFormatTypes) { - for (partitioned <- partitionedTypes) { - for (clustered <- clusteredTypes) { - for (acidType <- acidTypes) { - val tType = partitioned + clustered + fileFormat + acidType - if (partitioned != "") { - tblTypes += ((tType, true)) - } else { - tblTypes += ((tType, false)) - } - } - } - } - } - tblTypes.filter { - case (name, isPartitioned) => - // Filter out all non-orc, full acid tables - !(!name.toLowerCase().contains("orc") && name.contains(fullAcidStr)) - }.toList - } - - // Loop through all variations - def allFullAcidTypes(): List[(String, Boolean)] = { - val acidType = fullAcidStr - val fileFormatTypes = Array(orcStr) - val partitionedTypes = Array("", partitionedStr) - //val partitionedTypes = Array("") - val clusteredTypes = Array("") - - generateTableVariations(fileFormatTypes, partitionedTypes, clusteredTypes, Array(acidType)) - } - - // Loop through all variations - def allInsertOnlyTypes(): List[(String, Boolean)] = { - val acidType = insertOnlyStr - // NB: Avro not supported !!! - val fileFormatTypes = Array(orcStr) - val partitionedTypes = Array("", partitionedStr) - // val partitionedTypes = Array("") - val clusteredTypes = Array("") - - generateTableVariations(fileFormatTypes, partitionedTypes, clusteredTypes, Array(acidType)) - } - - def allNonAcidTypes(): List[(String, Boolean)] = { - val acidType = "" - val fileFormatTypes = Array(orcStr) - val partitionedTypes = Array("", partitionedStr) - //val partitionedTypes = Array("") - val clusteredTypes = Array("") - - generateTableVariations(fileFormatTypes, partitionedTypes, clusteredTypes, Array(acidType)) - } - - - def hiveJoin(table1: Table, table2: Table): String = { - s"SELECT * FROM ${table1.hiveTname} t1 JOIN ${table2.hiveTname} t2 WHERE t1.key = t2.key ORDER BY ${table1.hiveOrderBy("t1")} , ${table2.hiveOrderBy("t2")}" - } - - def sparkJoin(table1: Table, table2: Table): String = { - s"SELECT * FROM ${table1.sparkTname} t1 JOIN ${table2.sparkTname} t2 WHERE t1.key = t2.key ORDER BY ${table1.sparkOrderBy("t1")} , ${table2.sparkOrderBy("t2")}" - } -} diff --git a/src/test/scala/com/qubole/spark/datasources/hiveacid/HiveAcidSuite.scala b/src/test/scala/com/qubole/spark/hiveacid/ReadSuite.scala similarity index 80% rename from src/test/scala/com/qubole/spark/datasources/hiveacid/HiveAcidSuite.scala rename to src/test/scala/com/qubole/spark/hiveacid/ReadSuite.scala index 4df6f08..83dc5ba 100644 --- a/src/test/scala/com/qubole/spark/datasources/hiveacid/HiveAcidSuite.scala +++ b/src/test/scala/com/qubole/spark/hiveacid/ReadSuite.scala @@ -17,29 +17,26 @@ * limitations under the License. 
*/ -package com.qubole.spark.datasources.hiveacid +package com.qubole.spark.hiveacid -import org.apache.commons.logging.LogFactory -import org.apache.log4j.{Level, LogManager} -import org.apache.spark.internal.Logging +import org.apache.log4j.{Level, LogManager, Logger} import org.apache.spark.sql._ -import org.apache.spark.util._ import org.scalatest._ import scala.util.control.NonFatal -class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { +class ReadSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { - val log = LogManager.getLogger(this.getClass) + val log: Logger = LogManager.getLogger(this.getClass) log.setLevel(Level.INFO) - var helper: TestHelper = _; - val isDebug = false + var helper: TestHelper = _ + val isDebug = true val DEFAULT_DBNAME = "HiveTestDB" val defaultPred = " intCol < 5 " - val cols = Map( + val cols: Map[String, String] = Map( ("intCol","int"), ("doubleCol","double"), ("floatCol","float"), @@ -50,7 +47,7 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter override def beforeAll() { try { - helper = new TestHelper(); + helper = new TestHelper() if (isDebug) { log.setLevel(Level.DEBUG) } @@ -58,7 +55,7 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter // DB helper.hiveExecute("DROP DATABASE IF EXISTS "+ DEFAULT_DBNAME +" CASCADE") - helper.hiveExecute("CREATE DATABASE "+ DEFAULT_DBNAME) + helper.hiveExecute("CREATE DATABASE IF NOT EXISTS "+ DEFAULT_DBNAME) } catch { case NonFatal(e) => log.info("failed " + e) } @@ -68,23 +65,27 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter helper.destroy() } - // Test Run - readTest(Table.allFullAcidTypes, false) - readTest(Table.allInsertOnlyTypes, true) + readTest(Table.allFullAcidTypes(), insertOnly = false) + readTest(Table.allInsertOnlyTypes(), insertOnly = true) // NB: Cannot create merged table for insert only table - mergeTest(Table.allFullAcidTypes, false) + // mergeTest(Table.allFullAcidTypes, false) joinTest(Table.allFullAcidTypes(), Table.allFullAcidTypes()) joinTest(Table.allInsertOnlyTypes(), Table.allFullAcidTypes()) joinTest(Table.allInsertOnlyTypes(), Table.allInsertOnlyTypes()) - compactionTest(Table.allFullAcidTypes(), false) - compactionTest(Table.allInsertOnlyTypes(), true) + compactionTest(Table.allFullAcidTypes(), insertOnly = false) + compactionTest(Table.allInsertOnlyTypes(), insertOnly = true) // NB: No run for the insert only table. 
- nonAcidToAcidConversionTest(Table.allNonAcidTypes(), false) + nonAcidToFullAcidConversionTest(List( + (Table.orcTable, false), + (Table.orcPartitionedTable, true), + (Table.orcBucketedTable, false), + (Table.orcBucketedPartitionedTable, true) + )) // Run predicatePushdown test for InsertOnly/FullAcid, Partitioned/NonPartitioned tables // It should work in file formats which supports predicate pushdown - orc/parquet @@ -95,7 +96,7 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter (Table.orcInsertOnlyTable, false, true), (Table.parquetInsertOnlyTable, false, true), (Table.textInsertOnlyTable, false, false), - (Table.orcFullACIDTable, false, true), + (Table.orcFullACIDTable, false, true), (Table.orcPartitionedFullACIDTable, true, true) )) @@ -111,7 +112,7 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter val testName = "Simple Read Test for " + tName + " type " + tType test(testName) { val table = new Table(DEFAULT_DBNAME, tName, cols, tType, isPartitioned) - def code() = { + def code(): Unit = { helper.recreate(table) helper.hiveExecute(table.insertIntoHiveTableKeyRange(1, 10)) helper.verify(table, insertOnly) @@ -129,18 +130,18 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter val table = new Table(DEFAULT_DBNAME, tName, cols, tType, isPartitioned) def checkOutputRowsInLeafNode(df: DataFrame): Long = { - val tableScanNode = df.queryExecution.executedPlan.collectLeaves()(0) + val tableScanNode = df.queryExecution.executedPlan.collectLeaves().head val metricsMap = tableScanNode.metrics val dfRowsRead = metricsMap("numOutputRows").value log.info(s"dfRowsRead: $dfRowsRead") - return dfRowsRead + dfRowsRead } - def code() = { - helper.withSQLConf("spark.sql.acidDs.enablePredicatePushdown" -> "true") { - helper.recreate(table, true) + def code(): Unit = { + helper.withSQLConf("spark.sql.hiveAcid.enablePredicatePushdown" -> "true") { + helper.recreate(table) // Inserting 5 rows in different hive queries so that we will have 5 files - one for each row - (3 to 7).toSeq.foreach(k => helper.hiveExecute(table.insertIntoHiveTableKey(k))) + (3 to 7).foreach(k => helper.hiveExecute(table.insertIntoHiveTableKey(k))) val dfFromSql = helper.sparkSQL(table.sparkSelectWithPred(defaultPred)) val hiveResStr = helper.hiveExecuteQuery(table.hiveSelectWithPred(defaultPred)) @@ -155,7 +156,7 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter // sparkSQL("select count(*) FROM HiveTestDB.spark_t1 t1 where intCol < 5").collect() // Disable the pushdown - helper.sparkSQL("set spark.sql.acidDs.enablePredicatePushdown=false") + helper.sparkSQL("set spark.sql.hiveAcid.enablePredicatePushdown=false") val dfFromSql1 = helper.sparkSQL(table.sparkSelectWithPred(defaultPred)) helper.compareResult(hiveResStr, dfFromSql1.collect()) assert(checkOutputRowsInLeafNode(dfFromSql1) == 2L * 5) @@ -202,14 +203,14 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter // 4. Alter table and convert the table into ACID table. // 5. Create spark sym link table over the hive table. 
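  // The conversion in step 4 is plain Hive DDL of the form
  //   ALTER TABLE <db>.<table> SET TBLPROPERTIES
  //     ('transactional' = 'true', 'transactional_properties' = 'default')
  // (see the alterToTransactionalFullAcidTable helper in Table.scala).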
// VERIFY: Both spark reads are same as hive read - def nonAcidToAcidConversionTest(tTypes: List[(String,Boolean)], insertOnly: Boolean): Unit = { + def nonAcidToFullAcidConversionTest(tTypes: List[(String,Boolean)]): Unit = { tTypes.foreach { case (tType, isPartitioned) => val tName = "t1" val testName = "NonAcid to Acid conversion test for " + tName + " type " + tType test(testName) { val table = new Table(DEFAULT_DBNAME, tName, cols, tType, isPartitioned) - def code() = { - helper.recreate(table, false) + def code(): Unit = { + helper.recreate(table, createSymlinkSparkTables = false) helper.hiveExecute(table.insertIntoHiveTableKeyRange(1, 10)) val hiveResStr = helper.hiveExecuteQuery(table.hiveSelect) @@ -224,7 +225,7 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter helper.compareResult(hiveResStr, dfFromSql.collect()) helper.compareResult(hiveResStr, dfFromScala.collect()) - helper.verify(table, insertOnly) + helper.verify(table, insertOnly = false) } helper.myRun(testName, code) } @@ -233,7 +234,7 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter // Compaction Test // - // 1. Disable comaction on the table. + // 1. Disable compaction on the table. // 2. Insert bunch of rows into the table. // 4. Read entire table using hive client // 5. Delete few keys, to create delete delta files. @@ -261,12 +262,12 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter val testName = "Simple Compaction Test for " + tName + " type " + tType test(testName) { val table = new Table(DEFAULT_DBNAME, tName, cols, tType, isPartitioned) - def code() = { + def code(): Unit = { helper.recreate(table) helper.hiveExecute(table.disableCompaction) - helper.hiveExecute(table.insertIntoHiveTableKeyRange(1, 3)) + helper.hiveExecute(table.insertIntoHiveTableKeyRange(1, 10)) val hiveResStr = helper.hiveExecuteQuery(table.hiveSelect) @@ -281,7 +282,11 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter helper.hiveExecute(table.insertIntoHiveTableKey(13)) helper.hiveExecute(table.insertIntoHiveTableKey(14)) helper.hiveExecute(table.insertIntoHiveTableKey(15)) - compactAndTest(hiveResStr, df1, df2) + if (isPartitioned) { + compactPartitionedAndTest(hiveResStr, df1, df2, Seq(11,12,13,14,15)) + } else { + compactAndTest(hiveResStr, df1, df2) + } // Shortcut for insert Only if (! 
insertOnly) { @@ -289,17 +294,25 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter helper.hiveExecute(table.deleteFromHiveTableKey(4)) helper.hiveExecute(table.deleteFromHiveTableKey(5)) helper.hiveExecute(table.deleteFromHiveTableKey(6)) - compactAndTest(hiveResStr, df1, df2) + if (isPartitioned) { + compactPartitionedAndTest(hiveResStr, df1, df2, Seq(3,4,5,6)) + } else { + compactAndTest(hiveResStr, df1, df2) + } helper.hiveExecute(table.updateInHiveTableKey(7)) helper.hiveExecute(table.updateInHiveTableKey(8)) helper.hiveExecute(table.updateInHiveTableKey(9)) helper.hiveExecute(table.updateInHiveTableKey(10)) - compactAndTest(hiveResStr, df1, df2) + if (isPartitioned) { + compactPartitionedAndTest(hiveResStr, df1, df2, Seq(7,8,9,10)) + } else { + compactAndTest(hiveResStr, df1, df2) + } } } - def compactAndTest(hiveResStr: String, df1: DataFrame, df2: DataFrame) = { + def compactAndTest(hiveResStr: String, df1: DataFrame, df2: DataFrame): Unit = { helper.compareResult(hiveResStr, df1.collect()) helper.compareResult(hiveResStr, df2.collect()) helper.hiveExecute(table.minorCompaction) @@ -310,6 +323,17 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter helper.compareResult(hiveResStr, df2.collect()) } + def compactPartitionedAndTest(hiveResStr: String, df1: DataFrame, df2: DataFrame, keys: Seq[Int]): Unit = { + helper.compareResult(hiveResStr, df1.collect()) + helper.compareResult(hiveResStr, df2.collect()) + keys.foreach(k => helper.hiveExecute(table.minorPartitionCompaction(k))) + helper.compareResult(hiveResStr, df1.collect()) + helper.compareResult(hiveResStr, df2.collect()) + keys.foreach((k: Int) => helper.hiveExecute(table.majorPartitionCompaction(k))) + helper.compareResult(hiveResStr, df1.collect()) + helper.compareResult(hiveResStr, df2.collect()) + } + helper.myRun(testName, code) } } @@ -331,7 +355,7 @@ class HiveACIDSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfter test(testName) { val table1 = new Table(DEFAULT_DBNAME, tName1, cols, tType1, isPartitioned1) val table2 = new Table(DEFAULT_DBNAME, tName2, cols, tType2, isPartitioned2) - def code() = { + def code(): Unit = { helper.recreate(table1) helper.recreate(table2) diff --git a/src/test/scala/com/qubole/spark/hiveacid/Table.scala b/src/test/scala/com/qubole/spark/hiveacid/Table.scala new file mode 100644 index 0000000..247a280 --- /dev/null +++ b/src/test/scala/com/qubole/spark/hiveacid/Table.scala @@ -0,0 +1,250 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.qubole.spark.hiveacid + +import org.joda.time.DateTime +import scala.collection.mutable.{ListBuffer, HashMap} + +/* + * + * + */ +class Table ( + private val dbName: String, + private val tName: String, + private val extraColMap: Map[String, String], + private val tblProp: String, + val isPartitioned: Boolean = false) { + + private var colMap = Map("key" -> "int") ++ extraColMap + private var colMapWithPartitionedCols = if (isPartitioned) { + Map("ptnCol" -> "int") ++ colMap + } else { + colMap + } + + // NB Add date column as well apparently always in the end + private def getRow(key: Int): String = colMap.map( x => { + x._2 match { + case "date" => s"'${(new DateTime(((key * 1000L) + 151502791900L))).toString}'" + case _ => { + x._1 match { + case "ptnCol" => (key % 3).toString + case _ => key.toString + } + } + }}).mkString(", ") + {if (isPartitioned) s", '${(new DateTime(((key * 1000L) + 151502791900L))).toString}'" else ""} + + private def getColDefString = colMap.map(x => x._1 + " " + x._2).mkString(",") + + // FIXME: Add ptn_col column of partitioned table in order by clause + // private def sparkOrderBy: String = sparkOrderBy(sparkTname) + // private def hiveOrderBy: String = hiveOrderBy(tName) + private def sparkOrderBy(aliasedTable: String): String = + colMapWithPartitionedCols.map(x => s"${aliasedTable}.${x._1}").mkString(", ") + private def hiveOrderBy(aliasedTable: String): String = + colMapWithPartitionedCols.map(x => s"$aliasedTable.${x._1}").mkString(", ") + private def getCols = colMapWithPartitionedCols.map(x => x._1).mkString(", ") + + def getColMap = colMapWithPartitionedCols + + def hiveTname = s"$dbName.$tName" + def hiveTname1 = s"$tName" + def sparkTname = s"${dbName}.spark_${tName}" + + def hiveCreate = s"CREATE TABLE ${hiveTname} (${getColDefString}) ${tblProp}" + def hiveSelect = s"SELECT * FROM ${hiveTname} t1 ORDER BY ${hiveOrderBy("t1")}" + def hiveSelectWithPred(pred: String) = + s"SELECT * FROM ${hiveTname} t1 where ${pred} ORDER BY ${hiveOrderBy("t1")}" + def hiveSelectWithProj = s"SELECT intCol FROM ${hiveTname} ORDER BY intCol" + def hiveDrop = s"DROP TABLE IF EXISTS ${hiveTname}" + + def sparkCreate = s"CREATE TABLE ${sparkTname} USING HiveAcid OPTIONS('table' '${hiveTname}')" + def sparkSelect = s"SELECT * FROM ${sparkTname} t1 ORDER BY ${sparkOrderBy("t1")}" + def sparkSelectWithPred(pred: String) = + s"SELECT * FROM ${sparkTname} t1 where ${pred} ORDER BY ${sparkOrderBy("t1")}" + def sparkSelectWithProj = s"SELECT intCol FROM ${sparkTname} ORDER BY intCol" + def sparkDFProj = "intCol" + def sparkDrop = s"DROP TABLE IF EXISTS ${sparkTname}" + + + private def insertHiveTableKeyRange(startKey: Int, endKey: Int, operation: String): String = + s"INSERT ${operation} TABLE ${hiveTname} " + (startKey to endKey).map { key => s" select ${getRow(key)} " }.mkString(" UNION ALL ") + private def insertSparkTableKeyRange(startKey: Int, endKey: Int, operation: String): String = + s"INSERT ${operation} TABLE ${sparkTname} " + (startKey to endKey).map { key => s" select ${getRow(key)} " }.mkString(" UNION ALL ") + + def insertIntoHiveTableKeyRange(startKey: Int, endKey: Int): String = + insertHiveTableKeyRange(startKey, endKey, "INTO") + def insertIntoSparkTableKeyRange(startKey: Int, endKey: Int): String = + insertSparkTableKeyRange(startKey, endKey, "INTO") + def insertOverwriteHiveTableKeyRange(startKey: Int, endKey: Int): String = + insertHiveTableKeyRange(startKey, endKey, "OVERWRITE") + def insertOverwriteSparkTableKeyRange(startKey: Int, 
endKey: Int): String = + insertSparkTableKeyRange(startKey, endKey, "OVERWRITE") + + def insertIntoHiveTableKey(key: Int): String = + s"INSERT INTO ${hiveTname} (${getCols}) VALUES (${getRow(key)})" + def deleteFromHiveTableKey(key: Int): String = + s"DELETE FROM ${hiveTname} where key = ${key}" + def updateInHiveTableKey(key: Int): String = + s"UPDATE ${hiveTname} set intCol = intCol * 10 where key = ${key}" + + def updateByMergeHiveTable = + s" merge into ${hiveTname} t using (select distinct ${getCols} from ${hiveTname}) s on s.key=t.key " + + s" when matched and s.key%2=0 then update set intCol=s.intCol * 10 " + + s" when matched and s.key%2=1 then delete " + + s" when not matched then insert values(${getRow(1000)})" + + def disableCompaction = s"ALTER TABLE ${hiveTname} SET TBLPROPERTIES ('NO_AUTO_COMPACTION' = 'true')" + def disableCleanup = s"ALTER TABLE ${hiveTname} SET TBLPROPERTIES ('NO_CLEANUP' = 'true')" + def minorCompaction = s"ALTER TABLE ${hiveTname} COMPACT 'minor'" + def majorCompaction = s"ALTER TABLE ${hiveTname} COMPACT 'major'" + + def minorPartitionCompaction(ptnid: Int): String = { + s"ALTER TABLE ${hiveTname} PARTITION(ptnCol=${ptnid}) COMPACT 'minor'" + } + + def majorPartitionCompaction(ptnid: Int): String = { + s"ALTER TABLE ${hiveTname} PARTITION(ptnCol=${ptnid}) COMPACT 'major'" + } + + def alterToTransactionalInsertOnlyTable = + s"ALTER TABLE ${hiveTname} SET TBLPROPERTIES ('transactional'='true', 'transactional_properties'='insert_only')" + def alterToTransactionalFullAcidTable = + s"ALTER TABLE ${hiveTname} SET TBLPROPERTIES ('transactional'='true', 'transactional_properties'='default')" +} + +object Table { + + // Create table string builder + // 1st param + private val partitionedStr = "PARTITIONED BY (ptnCol int) " + // 2nd param + private val clusteredStr = "CLUSTERED BY(key) INTO 3 BUCKETS " + // 3rd param + private val orcStr = "STORED AS ORC " + private val parquetStr = "STORED AS PARQUET " + private val textStr = "STORED AS TEXTFILE " + private val avroStr = "STORED AS AVRO " + // 4th param + private val fullAcidStr = " TBLPROPERTIES ('transactional'='true', 'transactional_properties'='default') " + private val insertOnlyStr = " TBLPROPERTIES ('transactional'='true', 'transactional_properties'='insert_only') " + + val orcTable = orcStr + val orcPartitionedTable = partitionedStr + orcStr + val orcBucketedPartitionedTable = partitionedStr + clusteredStr + orcStr + val orcBucketedTable = clusteredStr + orcStr + + val orcFullACIDTable = orcStr + fullAcidStr + val orcPartitionedFullACIDTable = partitionedStr + orcStr + fullAcidStr + val orcBucketedPartitionedFullACIDTable = partitionedStr + clusteredStr + orcStr + fullAcidStr + val orcBucketedFullACIDTable = clusteredStr + orcStr + fullAcidStr + + val parquetFullACIDTable = parquetStr + fullAcidStr + val parquetPartitionedFullACIDTable = partitionedStr + parquetStr + fullAcidStr + val parquetBucketedPartitionedFullACIDTable = partitionedStr + clusteredStr + parquetStr + fullAcidStr + val parquetBucketedFullACIDTable = clusteredStr + parquetStr + fullAcidStr + + val textFullACIDTable = textStr + fullAcidStr + val textPartitionedFullACIDTable = partitionedStr + textStr + fullAcidStr + val textBucketedPartitionedFullACIDTable = partitionedStr + clusteredStr + textStr + fullAcidStr + val textBucketedFullACIDTable = clusteredStr + textStr + fullAcidStr + + val avroFullACIDTable = avroStr + fullAcidStr + val avroPartitionedFullACIDTable = partitionedStr + avroStr + fullAcidStr + val 
avroBucketedPartitionedFullACIDTable = partitionedStr + clusteredStr + avroStr + fullAcidStr + val avroBucketedFullACIDTable = clusteredStr + avroStr + fullAcidStr + + val orcInsertOnlyTable = orcStr + insertOnlyStr + val orcPartitionedInsertOnlyTable = partitionedStr + orcStr + insertOnlyStr + val orcBucketedPartitionedInsertOnlyTable = partitionedStr + clusteredStr + orcStr + insertOnlyStr + val orcBucketedInsertOnlyTable = clusteredStr + orcStr + insertOnlyStr + + val parquetInsertOnlyTable = parquetStr + insertOnlyStr + val parquetPartitionedInsertOnlyTable = partitionedStr + parquetStr + insertOnlyStr + val parquetBucketedPartitionedInsertOnlyTable = partitionedStr + clusteredStr + parquetStr + insertOnlyStr + val parquetBucketedInsertOnlyTable = clusteredStr + parquetStr + insertOnlyStr + + val textInsertOnlyTable = textStr + insertOnlyStr + val textPartitionedInsertOnlyTable = partitionedStr + textStr + insertOnlyStr + val textBucketedPartitionedInsertOnlyTable = partitionedStr + clusteredStr + textStr + insertOnlyStr + val textBucketedInsertOnlyTable = clusteredStr + textStr + insertOnlyStr + + val avroInsertOnlyTable = avroStr + insertOnlyStr + val avroPartitionedInsertOnlyTable = partitionedStr + avroStr + insertOnlyStr + val avroBucketedPartitionedInsertOnlyTable = partitionedStr + clusteredStr + avroStr + insertOnlyStr + val avroBucketedInsertOnlyTable = clusteredStr + avroStr + insertOnlyStr + + private def generateTableVariations(fileFormatTypes: Array[String], + partitionedTypes: Array[String], + clusteredTypes: Array[String], + acidTypes: Array[String]): List[(String, Boolean)] = { + var tblTypes = new ListBuffer[(String, Boolean)]() + for (fileFormat <- fileFormatTypes) { + for (partitioned <- partitionedTypes) { + for (clustered <- clusteredTypes) { + for (acidType <- acidTypes) { + val tType = partitioned + clustered + fileFormat + acidType + if (partitioned != "") { + tblTypes += ((tType, true)) + } else { + tblTypes += ((tType, false)) + } + } + } + } + } + tblTypes.filter { + case (name, isPartitioned) => + // Filter out all non-orc, full acid tables + !(!name.toLowerCase().contains("orc") && name.contains(fullAcidStr)) + }.toList + } + + // Loop through all variations + def allFullAcidTypes(): List[(String, Boolean)] = { + val acidType = fullAcidStr + val fileFormatTypes = Array(orcStr) + val partitionedTypes = Array("", partitionedStr) + val clusteredTypes = Array("", clusteredStr) + + generateTableVariations(fileFormatTypes, partitionedTypes, clusteredTypes, Array(acidType)) + } + + // Loop through all variations + def allInsertOnlyTypes(): List[(String, Boolean)] = { + val acidType = insertOnlyStr + // NB: Avro not supported !!! 
+ val fileFormatTypes = Array(orcStr, parquetStr, textStr) + val partitionedTypes = Array("", partitionedStr) + val clusteredTypes = Array("", clusteredStr) + + generateTableVariations(fileFormatTypes, partitionedTypes, clusteredTypes, Array(acidType)) + } + + + def hiveJoin(table1: Table, table2: Table): String = { + s"SELECT * FROM ${table1.hiveTname} t1 JOIN ${table2.hiveTname} t2 WHERE t1.key = t2.key ORDER BY ${table1.hiveOrderBy("t1")} , ${table2.hiveOrderBy("t2")}" + } + + def sparkJoin(table1: Table, table2: Table): String = { + s"SELECT * FROM ${table1.sparkTname} t1 JOIN ${table2.sparkTname} t2 WHERE t1.key = t2.key ORDER BY ${table1.sparkOrderBy("t1")} , ${table2.sparkOrderBy("t2")}" + } +} diff --git a/src/test/scala/com/qubole/spark/datasources/hiveacid/TestHelper.scala b/src/test/scala/com/qubole/spark/hiveacid/TestHelper.scala similarity index 79% rename from src/test/scala/com/qubole/spark/datasources/hiveacid/TestHelper.scala rename to src/test/scala/com/qubole/spark/hiveacid/TestHelper.scala index e1af030..6058a9a 100644 --- a/src/test/scala/com/qubole/spark/datasources/hiveacid/TestHelper.scala +++ b/src/test/scala/com/qubole/spark/hiveacid/TestHelper.scala @@ -17,18 +17,13 @@ * limitations under the License. */ -package com.qubole.spark.datasources.hiveacid +package com.qubole.spark.hiveacid -import java.io.StringWriter import java.net.URLClassLoader import java.net.URL -import org.apache.commons.logging.LogFactory -import org.apache.log4j.{Level, LogManager} -import org.apache.spark.internal.Logging +import org.apache.log4j.{Level, LogManager, Logger} import org.apache.spark.sql._ -import org.apache.spark.util._ - import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf @@ -45,7 +40,7 @@ class TestHelper { def init(isDebug: Boolean) { verbose = isDebug // Clients - spark = TestSparkSession.getSession() + spark = TestSparkSession.getSession if (verbose) { log.setLevel(Level.DEBUG) } @@ -68,17 +63,18 @@ class TestHelper { // 3. 
Read entire table using spark dataframe API // Verify: Both spark reads are same as hive read - // Simple + // Check the data present in this table via hive as well as spark sql and df private def compare(table: Table, msg: String): Unit = { - log.info(s"Verify simple ${msg}") + log.info(s"Verify simple $msg") val hiveResStr = hiveExecuteQuery(table.hiveSelect) val (dfFromSql, dfFromScala) = sparkGetDF(table) compareResult(hiveResStr, dfFromSql.collect()) compareResult(hiveResStr, dfFromScala.collect()) } + // With Predicate private def compareWithPred(table: Table, msg: String, pred: String): Unit = { - log.info(s"Verify with predicate ${msg}") + log.info(s"Verify with predicate $msg") val hiveResStr = hiveExecuteQuery(table.hiveSelectWithPred(pred)) val (dfFromSql, dfFromScala) = sparkGetDFWithPred(table, pred) compareResult(hiveResStr, dfFromSql.collect()) @@ -86,13 +82,42 @@ class TestHelper { } // With Projection private def compareWithProj(table: Table, msg: String): Unit = { - log.info(s"Verify with projection ${msg}") + log.info(s"Verify with projection $msg") val hiveResStr = hiveExecuteQuery(table.hiveSelectWithProj) val (dfFromSql, dfFromScala) = sparkGetDFWithProj(table) compareResult(hiveResStr, dfFromSql.collect()) compareResult(hiveResStr, dfFromScala.collect()) } + // Compare result of 2 tables via hive + def compareTwoTablesViaHive(table1: Table, table2: Table, msg: String, + expectedRows: Int = -1): Unit = { + log.info(s"Verify output of 2 tables via Hive: $msg") + val hiveResStr1 = hiveExecuteQuery(table1.hiveSelect) + val hiveResStr2 = hiveExecuteQuery(table2.hiveSelect) + assert(hiveResStr1 == hiveResStr2, s"out1: \n$hiveResStr1\nout2: \n$hiveResStr2\n") + if (expectedRows != -1) { + val resultRows = hiveResStr1.split("\n").length + assert(resultRows == expectedRows, s"Expected $expectedRows rows, got $resultRows rows " + + s"in output:\n$hiveResStr1") + } + } + + // Compare result of 2 tables via spark + def compareTwoTablesViaSpark(table1: Table, table2: Table, msg: String, + expectedRows: Int = -1): Unit = { + log.info(s"Verify output of 2 tables via Spark: $msg") + val sparkResRows1 = sparkCollect(table1.hiveSelect) + val sparkResRows2 = sparkCollect(table2.hiveSelect) + compareResult(sparkResRows1, sparkResRows2) + if (expectedRows != -1) { + val result = sparkRowsToStr(sparkResRows1) + val resultRows = result.split("\n").length + assert(resultRows == expectedRows, s"Expected $expectedRows rows, got $resultRows rows " + + s"in output:\n$result") + } + } + // 1. Insert some more rows into the table using hive client. // 2. Compare simple // 3. Delete some rows from the table using hive client. 
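For context, a minimal sketch (not part of the patch, illustrative only) of how the two-table comparison helpers added above are meant to be driven: write the same key range into one table via Hive and into another via the Spark datasource, then assert that both engines read identical contents. The `TestHelper` and `Table` utilities are the ones introduced in this diff; table names below are placeholders, and a reachable test Hive Metastore (as configured in `TestSparkSession`) is assumed.

    // Illustrative sketch only -- mirrors the pattern used by WriteSuite later in this patch.
    import com.qubole.spark.hiveacid.{Table, TestHelper}

    val helper = new TestHelper
    helper.init(isDebug = false)

    val cols = Map("intCol" -> "int", "doubleCol" -> "double")
    val tableHive  = new Table("HiveTestDB", "tHive",  cols, Table.orcFullACIDTable, isPartitioned = false)
    val tableSpark = new Table("HiveTestDB", "tSpark", cols, Table.orcFullACIDTable, isPartitioned = false)
    helper.recreate(tableHive)
    helper.recreate(tableSpark)

    // Insert the same key range through Hive and through the HiveAcid datasource ...
    helper.hiveExecute(tableHive.insertIntoHiveTableKeyRange(1, 10))
    helper.sparkSQL(tableSpark.insertIntoSparkTableKeyRange(1, 10))

    // ... then check that Hive and Spark both see identical contents (10 rows expected).
    helper.compareTwoTablesViaHive(tableHive, tableSpark, "after insert into", expectedRows = 10)
    helper.compareTwoTablesViaSpark(tableHive, tableSpark, "after insert into", expectedRows = 10)

    helper.destroy()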
@@ -169,7 +194,7 @@ class TestHelper { var dfScala = spark.read.format("HiveAcid").options(Map("table" -> table.hiveTname)).load().select(table.sparkDFProj) dfScala = totalOrderBy(table, dfScala) - return (dfSql, dfScala) + (dfSql, dfScala) } def sparkGetDFWithPred(table: Table, pred: String): (DataFrame, DataFrame) = { @@ -177,7 +202,7 @@ class TestHelper { var dfScala = spark.read.format("HiveAcid").options(Map("table" -> table.hiveTname)).load().where(col("intCol") < "5") dfScala = totalOrderBy(table, dfScala) - return (dfSql, dfScala) + (dfSql, dfScala) } def sparkGetDF(table: Table): (DataFrame, DataFrame) = { @@ -185,26 +210,26 @@ class TestHelper { var dfScala = spark.read.format("HiveAcid").options(Map("table" -> table.hiveTname)).load() dfScala = totalOrderBy(table, dfScala) - return (dfSql, dfScala) + (dfSql, dfScala) } def sparkSQL(cmd: String): DataFrame = { - log.debug(s"Spark> ${cmd}\n") + log.debug(s"Spark> $cmd\n") spark.sql(cmd) } def sparkCollect(cmd: String): Array[Row] = { - log.debug(s"Spark> ${cmd}\n") + log.debug(s"Spark> $cmd\n") spark.sql(cmd).collect() } def hiveExecute(cmd: String): Any = { - log.debug(s"Hive> ${cmd}\n"); + log.debug(s"Hive> $cmd\n") hiveClient.execute(cmd) } def hiveExecuteQuery(cmd: String): String = { - log.debug(s"Hive> ${cmd}\n"); + log.debug(s"Hive> $cmd\n") hiveClient.executeQuery(cmd) } @@ -231,6 +256,14 @@ class TestHelper { assert(hiveResStr == sparkResStr) } + // Compare the results + def compareResult(sparkRes1: Array[Row], sparkRes2: Array[Row]): Unit = { + val sparkResStr1 = sparkRowsToStr(sparkRes1) + val sparkResStr2 = sparkRowsToStr(sparkRes2) + log.debug(s"Comparing \n hive: $sparkResStr1 \n Spark: $sparkResStr2") + assert(sparkResStr1 == sparkResStr2) + } + // Convert Array of Spark Rows into a String private def sparkRowsToStr(rows: Array[Row]): String = { rows.map(row => row.mkString(",")).mkString("\n") @@ -245,7 +278,7 @@ class TestHelper { code() } catch { case NonFatal(e) => - log.info(s"Failed test[${testName}]:$e") + log.info(s"Failed test[$testName]:$e") throw e } } @@ -275,10 +308,10 @@ class TestHelper { } // Given a className, identify all the jars in classpath that contains the class - def getJarsForClass(className: String): Unit = { + def printJarsForClass(className: String): Unit = { def list_urls(cl: ClassLoader): Array[java.net.URL] = cl match { case null => Array() - case u: java.net.URLClassLoader => u.getURLs() ++ list_urls(cl.getParent) + case u: java.net.URLClassLoader => u.getURLs ++ list_urls(cl.getParent) case _ => list_urls(cl.getParent) } @@ -291,27 +324,27 @@ class TestHelper { resultJarsArray = resultJarsArray :+ classPath } } - return resultJarsArray + resultJarsArray } val allJars = list_urls(getClass.getClassLoader).distinct val requiredJars = findJarsHavingClass(className, allJars) - log.info(s"Class: $className found in following ${requiredJars.size} jars:") + log.info(s"Class: $className found in following ${requiredJars.length} jars:") requiredJars.foreach(uri => log.info(uri.toString)) } } object TestHelper { - val log = LogManager.getLogger(this.getClass) + val log: Logger = LogManager.getLogger(this.getClass) log.setLevel(Level.INFO) // Given a className, identify all the jars in classpath that contains the class - def getJarsForClass(className: String): Unit = { + def printJarsForClass(className: String): Unit = { def list_urls(cl: ClassLoader): Array[java.net.URL] = cl match { case null => Array() - case u: java.net.URLClassLoader => u.getURLs() ++ list_urls(cl.getParent) + case u: 
java.net.URLClassLoader => u.getURLs ++ list_urls(cl.getParent) case _ => list_urls(cl.getParent) } @@ -324,13 +357,13 @@ object TestHelper { resultJarsArray = resultJarsArray :+ classPath } } - return resultJarsArray + resultJarsArray } val allJars = list_urls(getClass.getClassLoader).distinct val requiredJars = findJarsHavingClass(className, allJars) - log.info(s"Class: $className found in following ${requiredJars.size} jars:") + log.info(s"Class: $className found in following ${requiredJars.length} jars:") requiredJars.foreach(uri => log.info(uri.toString)) } } diff --git a/src/test/scala/com/qubole/spark/datasources/hiveacid/TestSparkSession.scala b/src/test/scala/com/qubole/spark/hiveacid/TestSparkSession.scala similarity index 83% rename from src/test/scala/com/qubole/spark/datasources/hiveacid/TestSparkSession.scala rename to src/test/scala/com/qubole/spark/hiveacid/TestSparkSession.scala index bbce67f..f0b8e4c 100644 --- a/src/test/scala/com/qubole/spark/datasources/hiveacid/TestSparkSession.scala +++ b/src/test/scala/com/qubole/spark/hiveacid/TestSparkSession.scala @@ -16,23 +16,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.qubole.spark.datasources.hiveacid +package com.qubole.spark.hiveacid import org.apache.spark.sql.SparkSession private[hiveacid] object TestSparkSession { - val spark = SparkSession.builder().appName("Hive-acid-test") + val spark: SparkSession = SparkSession.builder().appName("Hive-acid-test") .master("local[*]") .config("spark.hadoop.hive.metastore.uris", "thrift://0.0.0.0:10000") .config("spark.sql.warehouse.dir", "/tmp") + .config("spark.sql.extensions", "com.qubole.spark.hiveacid.HiveAcidAutoConvertExtension") //.config("spark.ui.enabled", "true") //.config("spark.ui.port", "4041") .enableHiveSupport() .getOrCreate() - def getSession(): SparkSession = { + def getSession: SparkSession = { spark.sparkContext.setLogLevel("WARN") - return spark + spark } } diff --git a/src/test/scala/com/qubole/spark/hiveacid/WriteSuite.scala b/src/test/scala/com/qubole/spark/hiveacid/WriteSuite.scala new file mode 100644 index 0000000..48ae2c9 --- /dev/null +++ b/src/test/scala/com/qubole/spark/hiveacid/WriteSuite.scala @@ -0,0 +1,137 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.qubole.spark.hiveacid + + +import org.apache.log4j.{Level, LogManager, Logger} +import org.scalatest._ + +import scala.util.control.NonFatal + +class WriteSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { + + val log: Logger = LogManager.getLogger(this.getClass) + log.setLevel(Level.INFO) + + var helper: TestHelper = _ + val isDebug = true + + val DEFAULT_DBNAME = "HiveTestDB" + val defaultPred = " intCol < 5 " + val cols: Map[String, String] = Map( + ("intCol","int"), + ("doubleCol","double"), + ("floatCol","float"), + ("booleanCol","boolean") + // TODO: Requires spark.sql.hive.convertMetastoreOrc=false to run + // ("dateCol","date") + ) + + override def beforeAll() { + try { + + helper = new TestHelper + if (isDebug) { + log.setLevel(Level.DEBUG) + } + helper.init(isDebug) + + // DB + helper.hiveExecute("DROP DATABASE IF EXISTS "+ DEFAULT_DBNAME +" CASCADE") + helper.hiveExecute("CREATE DATABASE "+ DEFAULT_DBNAME) + } catch { + case NonFatal(e) => log.info("failed " + e) + } + } + + override protected def afterAll(): Unit = { + helper.destroy() + } + + + // Test Run + insertIntoOverwriteTestForFullAcidTables(Table.allFullAcidTypes()) + + // TODO: Currently requires compatibility check to be disabled in HMS to run clean + // hive.metastore.client.capability.check=false + // insertIntoOverwriteTestForInsertOnlyTables(Table.allInsertOnlyTypes()) + + // Insert Into/Overwrite test for full acid tables + def insertIntoOverwriteTestForFullAcidTables(tTypes: List[(String,Boolean)]): Unit = { + tTypes.foreach { case (tType, isPartitioned) => + val tableNameHive = "tHive" + val tableNameSpark = "tSpark" + val testName = s"Simple InsertInto Test for $tableNameHive/$tableNameSpark type $tType" + test(testName) { + val tableHive = new Table(DEFAULT_DBNAME, tableNameHive, cols, tType, isPartitioned) + val tableSpark = new Table(DEFAULT_DBNAME, tableNameSpark, cols, tType, isPartitioned) + def code(): Unit = { + helper.recreate(tableHive) + helper.recreate(tableSpark) + + // Insert into rows in both tables from Hive and Spark + helper.hiveExecute(tableHive.insertIntoHiveTableKeyRange(11, 20)) + helper.sparkSQL(tableSpark.insertIntoSparkTableKeyRange(11, 20)) + var expectedRows = 10 + helper.compareTwoTablesViaHive(tableHive, tableSpark, "After Insert Into", expectedRows) + helper.compareTwoTablesViaSpark(tableHive, tableSpark, "After Insert Into", expectedRows) + + // Insert overwrite rows in both tables from Hive and Spark + helper.hiveExecute(tableHive.insertOverwriteHiveTableKeyRange(16, 25)) + helper.sparkSQL(tableSpark.insertOverwriteSparkTableKeyRange(16, 25)) + expectedRows = if (tableHive.isPartitioned) 15 else 10 + helper.compareTwoTablesViaHive(tableHive, tableSpark, "After Insert Overwrite", expectedRows) + helper.compareTwoTablesViaSpark(tableHive, tableSpark, "After Insert Overwrite", expectedRows) + + // Insert overwrite rows in both tables - add rows in hive table from spark and vice versa + helper.hiveExecute(tableSpark.insertOverwriteHiveTableKeyRange(24, 27)) + helper.sparkSQL(tableHive.insertOverwriteSparkTableKeyRange(24, 27)) + expectedRows = if (tableHive.isPartitioned) expectedRows + 2 else 4 + helper.compareTwoTablesViaHive(tableHive, tableSpark, "After Insert Overwrite", expectedRows) + helper.compareTwoTablesViaSpark(tableHive, tableSpark, "After Insert Overwrite", expectedRows) + + // Insert into rows in both tables - add rows in hive table from spark and vice versa + helper.hiveExecute(tableSpark.insertIntoHiveTableKeyRange(24, 27)) 
+ helper.sparkSQL(tableHive.insertIntoSparkTableKeyRange(24, 27)) + expectedRows = expectedRows + 4 + helper.compareTwoTablesViaHive(tableHive, tableSpark, "After Insert Into", expectedRows) + helper.compareTwoTablesViaSpark(tableHive, tableSpark, "After Insert Into", expectedRows) + + } + helper.myRun(testName, code) + } + } + } + + def insertIntoOverwriteTestForInsertOnlyTables(tTypes: List[(String,Boolean)]): Unit = { + tTypes.foreach { case (tType, isPartitioned) => + val tableNameSpark = "tSpark" + val testName = s"Simple InsertInto Test for $tableNameSpark type $tType" + test(testName) { + val tableSpark = new Table(DEFAULT_DBNAME, tableNameSpark, cols, tType, isPartitioned) + def code() = { + helper.recreate(tableSpark) + } + helper.myRun(testName, code) + } + } + } + +} diff --git a/version.sbt b/version.sbt index 9f0e5a5..d3a524b 100644 --- a/version.sbt +++ b/version.sbt @@ -1 +1 @@ -version in ThisBuild := "0.4.0" +version in ThisBuild := "0.4.4"
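For orientation on the SQL that the new `Table` helper generates, here is a small illustrative snippet (not part of the patch; database and table names are placeholders) printing the DDL and the per-partition compaction statements that the compaction test above issues for a partitioned full-ACID ORC table. Exact whitespace and generated row values will differ.

    // Illustrative sketch only: inspect the statements the Table helper from this patch builds.
    import com.qubole.spark.hiveacid.Table

    val cols  = Map("intCol" -> "int")
    val table = new Table("HiveTestDB", "t1", cols, Table.orcPartitionedFullACIDTable, isPartitioned = true)

    println(table.hiveCreate)
    // e.g. CREATE TABLE HiveTestDB.t1 (key int,intCol int) PARTITIONED BY (ptnCol int)
    //      STORED AS ORC TBLPROPERTIES ('transactional'='true', 'transactional_properties'='default')

    println(table.disableCompaction)
    // ALTER TABLE HiveTestDB.t1 SET TBLPROPERTIES ('NO_AUTO_COMPACTION' = 'true')

    println(table.minorPartitionCompaction(1))
    // ALTER TABLE HiveTestDB.t1 PARTITION(ptnCol=1) COMPACT 'minor'

    println(table.majorPartitionCompaction(1))
    // ALTER TABLE HiveTestDB.t1 PARTITION(ptnCol=1) COMPACT 'major'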