Support writing with functions in distribute/partition expressions #253

Open · wants to merge 20 commits into base: main

Changes from all commits · 20 commits
fcdf0f0
Spark 3.4: Support distribute by any predefined transform
Yxang May 17, 2023
e52b714
Spark 3.4: add udf: years, days, hours, murmurHash2 and murmurHash3. …
Yxang May 18, 2023
ff243b5
Spark 3.4: Fixup sharding key needs to be mod by cluster weight on lo…
Yxang May 19, 2023
a1d4dce
Scala 2.13: Fix Spark 3.4 compile issue
Yxang May 19, 2023
5ddb98f
Spark 3.4: Optimize sharding key handling when shuffle and sort
Yxang May 22, 2023
000638e
Spark 3.4: Optimize sharding key handling when shuffle and sort, appr…
Yxang May 22, 2023
59f3bed
Spark 3.4: Support variable length arguments for murmurHash (up to 5 …
Yxang May 23, 2023
af14b3a
Spark 3.4: add CityHash64
Yxang May 24, 2023
22f191a
Spark 3.4: Optimize sharding key handling when shuffle and sort, appr…
Yxang May 26, 2023
ea5ed0e
Spark 3.4 UDF: Amend input type, Make clickhouse function nullable, b…
Yxang May 26, 2023
a8bdcbf
Spark 3.4: Optimize sharding key handling when shuffle and sort, amen…
Yxang May 30, 2023
3dcdd81
Spark 3.4: Change ExprUtils to implicit
Yxang Jun 2, 2023
386ddb0
Spark 3.4 UDF: clickhouse code reference using tag from commit hash
Yxang Jun 25, 2023
286c21f
Spark 3.4 UDF: support varargs for Hash UDFs
Yxang Jun 26, 2023
e5809f7
Spark 3.4: refactor implicit into normal arg in ExprUtils
Yxang Jun 27, 2023
5ae4f3d
Spark 3.4: Cast type when calling projection, support recursive resolve
Yxang Jun 27, 2023
088bf3d
Spark 3.4 UDF: change pmod to mod because positiveModulo does not exi…
Yxang Jul 14, 2023
85a025f
Docs: add comment for modulo UDF
Yxang Jul 14, 2023
4e201d6
Spark 3.4: Adapt to hash function under clickhouse-core
Yxang Jul 25, 2023
085b3ad
fix style
Yxang Jul 26, 2023
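Taken together, these commits let the connector honor a Distributed table's sharding expression (e.g. cityHash64(value)) when repartitioning and sorting before a write. A minimal usage sketch, modeled on the ClusterShardByTransformSuite further down in this diff; the cluster, database, and table names are illustrative, and runClickHouseSQL is the helper from that test suite:

// Sketch only: assumes a ClickHouse cluster named "single_replica" and a Spark session with the
// ClickHouse catalog configured; convertLocal makes Spark write directly to the local tables.
spark.conf.set("spark.clickhouse.write.distributed.convertLocal", "true")

spark.sql(
  """CREATE TABLE db.tbl_local (
    |  create_time TIMESTAMP NOT NULL,
    |  value       STRING    NOT NULL
    |) USING ClickHouse
    |TBLPROPERTIES (cluster = 'single_replica', engine = 'MergeTree()', order_by = 'create_time')
    |""".stripMargin)

// Distributed table sharded by a ClickHouse function, created directly in ClickHouse.
runClickHouseSQL(
  """CREATE TABLE db.tbl ON CLUSTER single_replica AS db.tbl_local
    |ENGINE = Distributed(single_replica, 'db', 'tbl_local', cityHash64(value))
    |""".stripMargin)

// The write path now repartitions/sorts by the Spark-side equivalent, clickhouse_cityHash64(value).
spark.sql("INSERT INTO db.tbl VALUES (timestamp'2021-01-01 10:10:10', '1')")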
@@ -97,4 +97,6 @@ case class ClusterSpec(
override def toString: String = s"cluster: $name, shards: [${shards.mkString(", ")}]"

@JsonIgnore @transient override lazy val nodes: Array[NodeSpec] = shards.sorted.flatMap(_.nodes)

def totalWeight: Int = shards.map(_.weight).sum
}
@@ -0,0 +1,109 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.clickhouse.cluster

import org.apache.spark.sql.clickhouse.TestUtils.om
import xenon.clickhouse.func.{
ClickHouseXxHash64Shard,
CompositeFunctionRegistry,
DynamicFunctionRegistry,
StaticFunctionRegistry
}

import java.lang.{Long => JLong}

class ClickHouseClusterHashUDFSuite extends SparkClickHouseClusterTest {
// used only to query function name mappings
val dummyRegistry: CompositeFunctionRegistry = {
val dynamicFunctionRegistry = new DynamicFunctionRegistry
val xxHash64ShardFunc = new ClickHouseXxHash64Shard(Seq.empty)
dynamicFunctionRegistry.register("ck_xx_hash64_shard", xxHash64ShardFunc) // for compatible
dynamicFunctionRegistry.register("clickhouse_shard_xxHash64", xxHash64ShardFunc)
new CompositeFunctionRegistry(Array(StaticFunctionRegistry, dynamicFunctionRegistry))
}

def runTest(funcSparkName: String, funcCkName: String, stringVal: String): Unit = {
val sparkResult = spark.sql(
s"""SELECT
| $funcSparkName($stringVal) AS hash_value
|""".stripMargin
).collect
assert(sparkResult.length == 1)
val sparkHashVal = sparkResult.head.getAs[Long]("hash_value")

val clickhouseResultJsonStr = runClickHouseSQL(
s"""SELECT
| $funcCkName($stringVal) AS hash_value
|""".stripMargin
).head.getString(0)
val clickhouseResultJson = om.readTree(clickhouseResultJsonStr)
val clickhouseHashVal = JLong.parseUnsignedLong(clickhouseResultJson.get("hash_value").asText)
assert(
sparkHashVal == clickhouseHashVal,
s"ck_function: $funcCkName, spark_function: $funcSparkName, args: ($stringVal)"
)
}

Seq(
"clickhouse_xxHash64",
"clickhouse_murmurHash3_64",
"clickhouse_murmurHash3_32",
"clickhouse_murmurHash2_64",
"clickhouse_murmurHash2_32",
"clickhouse_cityHash64"
).foreach { funcSparkName =>
val funcCkName = dummyRegistry.getFuncMappingBySpark(funcSparkName)
test(s"UDF $funcSparkName") {
Seq(
"spark-clickhouse-connector",
"Apache Spark",
"ClickHouse",
"Yandex",
"热爱",
"在传统的行式数据库系统中,数据按如下顺序存储:",
"🇨🇳"
).foreach { rawStringVal =>
val stringVal = s"\'$rawStringVal\'"
runTest(funcSparkName, funcCkName, stringVal)
}
}
}

Seq(
"clickhouse_murmurHash3_64",
"clickhouse_murmurHash3_32",
"clickhouse_murmurHash2_64",
"clickhouse_murmurHash2_32",
"clickhouse_cityHash64"
).foreach { funcSparkName =>
val funcCkName = dummyRegistry.getFuncMappingBySpark(funcSparkName)
test(s"UDF $funcSparkName multiple args") {
val strings = Seq(
"\'spark-clickhouse-connector\'",
"\'Apache Spark\'",
"\'ClickHouse\'",
"\'Yandex\'",
"\'热爱\'",
"\'在传统的行式数据库系统中,数据按如下顺序存储:\'",
"\'🇨🇳\'"
)
val test_5 = strings.combinations(5)
test_5.foreach { seq =>
val stringVal = seq.mkString(", ")
runTest(funcSparkName, funcCkName, stringVal)
}
}
}
}

This file was deleted.

@@ -0,0 +1,117 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.clickhouse.cluster

import org.apache.spark.SparkConf
import org.apache.spark.sql.Row

class ClusterShardByTransformSuite extends SparkClickHouseClusterTest {
override protected def sparkConf: SparkConf = {
val _conf = super.sparkConf
.set("spark.clickhouse.write.distributed.convertLocal", "true")
_conf
}

def runTest(func_name: String, func_args: Array[String]): Unit = {
val func_expr = s"$func_name(${func_args.mkString(",")})"
val cluster = "single_replica"
val db = s"db_${func_name}_shard_transform"
val tbl_dist = s"tbl_${func_name}_shard"
val tbl_local = s"${tbl_dist}_local"

try {
runClickHouseSQL(s"CREATE DATABASE IF NOT EXISTS $db ON CLUSTER $cluster")

spark.sql(
s"""CREATE TABLE $db.$tbl_local (
| create_time TIMESTAMP NOT NULL,
| create_date DATE NOT NULL,
| value STRING NOT NULL
|) USING ClickHouse
|TBLPROPERTIES (
| cluster = '$cluster',
| engine = 'MergeTree()',
| order_by = 'create_time'
|)
|""".stripMargin
)

runClickHouseSQL(
s"""CREATE TABLE $db.$tbl_dist ON CLUSTER $cluster
|AS $db.$tbl_local
|ENGINE = Distributed($cluster, '$db', '$tbl_local', $func_expr)
|""".stripMargin
)
spark.sql(
s"""INSERT INTO `$db`.`$tbl_dist`
|VALUES
| (timestamp'2021-01-01 10:10:10', date'2021-01-01', '1'),
| (timestamp'2022-02-02 11:10:10', date'2022-02-02', '2'),
| (timestamp'2023-03-03 12:10:10', date'2023-03-03', '3'),
| (timestamp'2024-04-04 13:10:10', date'2024-04-04', '4')
| AS tab(create_time, create_date, value)
|""".stripMargin
)
// check that data is indeed written
checkAnswer(
spark.table(s"$db.$tbl_dist").select("value").orderBy("create_time"),
Seq(Row("1"), Row("2"), Row("3"), Row("4"))
)

// check that rows written via Spark land on the same server as ClickHouse's native sharding would place them
runClickHouseSQL(
s"""INSERT INTO `$db`.`$tbl_dist`
|VALUES
| (timestamp'2021-01-01 10:10:10', date'2021-01-01', '1'),
| (timestamp'2022-02-02 11:10:10', date'2022-02-02', '2'),
| (timestamp'2023-03-03 12:10:10', date'2023-03-03', '3'),
| (timestamp'2024-04-04 13:10:10', date'2024-04-04', '4')
|""".stripMargin
)
checkAnswer(
spark.table(s"$db.$tbl_local")
.groupBy("value").count().filter("count != 2"),
Seq.empty
)

} finally {
runClickHouseSQL(s"DROP TABLE IF EXISTS $db.$tbl_dist ON CLUSTER $cluster")
runClickHouseSQL(s"DROP TABLE IF EXISTS $db.$tbl_local ON CLUSTER $cluster")
runClickHouseSQL(s"DROP DATABASE IF EXISTS $db ON CLUSTER $cluster")
}
}

Seq(
// wait for SPARK-44180 to be fixed, then add implicit cast test cases
("toYear", Array("create_date")),
// ("toYear", Array("create_time")),
("toYYYYMM", Array("create_date")),
// ("toYYYYMM", Array("create_time")),
("toYYYYMMDD", Array("create_date")),
// ("toYYYYMMDD", Array("create_time")),
("toHour", Array("create_time")),
("xxHash64", Array("value")),
("murmurHash2_64", Array("value")),
("murmurHash2_32", Array("value")),
("murmurHash3_64", Array("value")),
("murmurHash3_32", Array("value")),
("cityHash64", Array("value")),
("modulo", Array("toYYYYMM(create_date)", "10"))
).foreach {
case (func_name: String, func_args: Array[String]) =>
test(s"shard by $func_name(${func_args.mkString(",")})")(runTest(func_name, func_args))
}

}
@@ -78,12 +78,8 @@ class WriteDistributionAndOrderingSuite extends SparkClickHouseSingleTest {
WRITE_REPARTITION_BY_PARTITION.key -> repartitionByPartition.toString,
WRITE_LOCAL_SORT_BY_KEY.key -> localSortByKey.toString
) {
if (!ignoreUnsupportedTransform && repartitionByPartition) {
intercept[AnalysisException](write())
} else {
write()
check()
}
write()
check()
}

Seq(true, false).foreach { ignoreUnsupportedTransform =>
@@ -15,106 +15,176 @@
package org.apache.spark.sql.clickhouse

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression}
import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException
import org.apache.spark.sql.catalyst.{expressions, SQLConfHelper}
import org.apache.spark.sql.catalyst.expressions.{
BoundReference,
Cast,
Expression,
TransformExpression,
V2ExpressionUtils
}
import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.IGNORE_UNSUPPORTED_TRANSFORM
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.connector.expressions.Expressions._
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, _}
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
import org.apache.spark.sql.types.{StructField, StructType}
import xenon.clickhouse.exception.CHClientException
import xenon.clickhouse.expr._
import xenon.clickhouse.func.FunctionRegistry
import xenon.clickhouse.spec.ClusterSpec

import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}

object ExprUtils extends SQLConfHelper {
object ExprUtils extends SQLConfHelper with Serializable {

def toSparkPartitions(partitionKey: Option[List[Expr]]): Array[Transform] =
partitionKey.seq.flatten.flatten(toSparkTransformOpt).toArray
def toSparkPartitions(
partitionKey: Option[List[Expr]],
functionRegistry: FunctionRegistry
): Array[Transform] =
partitionKey.seq.flatten.flatten(toSparkTransformOpt(_, functionRegistry)).toArray

def toSparkSplits(shardingKey: Option[Expr], partitionKey: Option[List[Expr]]): Array[Transform] =
(shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt).toArray
def toSparkSplits(
shardingKey: Option[Expr],
partitionKey: Option[List[Expr]],
functionRegistry: FunctionRegistry
): Array[Transform] =
(shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt(_, functionRegistry)).toArray

def toSparkSortOrders(
shardingKeyIgnoreRand: Option[Expr],
partitionKey: Option[List[Expr]],
sortingKey: Option[List[OrderExpr]]
sortingKey: Option[List[OrderExpr]],
cluster: Option[ClusterSpec],
functionRegistry: FunctionRegistry
): Array[SortOrder] =
toSparkSplits(shardingKeyIgnoreRand, partitionKey).map(Expressions.sort(_, SortDirection.ASCENDING)) ++:
toSparkSplits(
shardingKeyIgnoreRand.map(k => ExprUtils.toSplitWithModulo(k, cluster.get.totalWeight)),
partitionKey,
functionRegistry
).map(Expressions.sort(_, SortDirection.ASCENDING)) ++:
sortingKey.seq.flatten.flatten { case OrderExpr(expr, asc, nullFirst) =>
val direction = if (asc) SortDirection.ASCENDING else SortDirection.DESCENDING
val nullOrder = if (nullFirst) NullOrdering.NULLS_FIRST else NullOrdering.NULLS_LAST
toSparkTransformOpt(expr).map(trans => Expressions.sort(trans, direction, nullOrder))
toSparkTransformOpt(expr, functionRegistry).map(trans =>
Expressions.sort(trans, direction, nullOrder)
)
}.toArray

@tailrec
def toCatalyst(v2Expr: V2Expression, fields: Array[StructField]): Expression =
private def loadV2FunctionOpt(
name: String,
args: Seq[Expression],
functionRegistry: FunctionRegistry
): Option[BoundFunction] = {
def loadFunction(ident: Identifier): UnboundFunction =
functionRegistry.load(ident.name).getOrElse(throw new NoSuchFunctionException(ident))
val inputType = StructType(args.zipWithIndex.map {
case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
})
try {
val unbound = loadFunction(Identifier.of(Array.empty, name))
Some(unbound.bind(inputType))
} catch {
case e: NoSuchFunctionException =>
throw e
case _: UnsupportedOperationException if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) =>
None
case e: UnsupportedOperationException =>
throw new AnalysisException(e.getMessage, cause = Some(e))
}
}

def resolveTransformCatalyst(
catalystExpr: Expression,
timeZoneId: Option[String] = None
): Expression = catalystExpr match {
case TransformExpression(function: ScalarFunction[_], args, _) =>
val resolvedArgs: Seq[Expression] = args.map(resolveTransformCatalyst(_, timeZoneId))
val castedArgs: Seq[Expression] = resolvedArgs.zip(function.inputTypes()).map {
case (arg, expectedType) if !arg.dataType.sameType(expectedType) => Cast(arg, expectedType, timeZoneId)
case (arg, _) => arg
}
V2ExpressionUtils.resolveScalarFunction(function, castedArgs)
case other => other
}

def toCatalyst(
v2Expr: V2Expression,
fields: Array[StructField],
functionRegistry: FunctionRegistry
): Expression =
v2Expr match {
case IdentityTransform(ref) => toCatalyst(ref, fields)
case IdentityTransform(ref) => toCatalyst(ref, fields, functionRegistry)
case ref: NamedReference if ref.fieldNames.length == 1 =>
val (field, ordinal) = fields
.zipWithIndex
.find { case (field, _) => field.name == ref.fieldNames.head }
.getOrElse(throw CHClientException(s"Invalid field reference: $ref"))
BoundReference(ordinal, field.dataType, field.nullable)
case t: Transform =>
val catalystArgs = t.arguments().map(toCatalyst(_, fields, functionRegistry))
loadV2FunctionOpt(t.name(), catalystArgs, functionRegistry)
.map(bound => TransformExpression(bound, catalystArgs)).getOrElse {
throw CHClientException(s"Unsupported expression: $v2Expr")
}
case literal: LiteralValue[Any] => expressions.Literal(literal.value)
case _ => throw CHClientException(
s"Unsupported V2 expression: $v2Expr, SPARK-33779: Spark 3.3 only support IdentityTransform"
s"Unsupported expression: $v2Expr"
)
}

def toSparkTransformOpt(expr: Expr): Option[Transform] = Try(toSparkTransform(expr)) match {
case Success(t) => Some(t)
case Failure(_) if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => None
case Failure(rethrow) => throw new AnalysisException(rethrow.getMessage, cause = Some(rethrow))
}

// Some functions of ClickHouse which match Spark pre-defined Transforms
//
// toYear, YEAR - Converts a date or date with time to a UInt16 (AD)
// toYYYYMM - Converts a date or date with time to a UInt32 (YYYY*100 + MM)
// toYYYYMMDD - Converts a date or date with time to a UInt32 (YYYY*10000 + MM*100 + DD)
// toHour, HOUR - Converts a date with time to a UInt8 (0-23)
def toSparkTransformOpt(expr: Expr, functionRegistry: FunctionRegistry): Option[Transform] =
Try(toSparkExpression(expr, functionRegistry)) match {
// needed because Spark `Table`'s `partitioning` field must be `Transform`s
case Success(t: Transform) => Some(t)
case Success(_) => None
case Failure(_) if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => None
case Failure(rethrow) => throw new AnalysisException(rethrow.getMessage, cause = Some(rethrow))
}

def toSparkTransform(expr: Expr): Transform = expr match {
case FieldRef(col) => identity(col)
case FuncExpr("toYear", List(FieldRef(col))) => years(col)
case FuncExpr("YEAR", List(FieldRef(col))) => years(col)
case FuncExpr("toYYYYMM", List(FieldRef(col))) => months(col)
case FuncExpr("toYYYYMMDD", List(FieldRef(col))) => days(col)
case FuncExpr("toHour", List(FieldRef(col))) => hours(col)
case FuncExpr("HOUR", List(FieldRef(col))) => hours(col)
// TODO support arbitrary functions
// case FuncExpr("xxHash64", List(FieldRef(col))) => apply("ck_xx_hash64", column(col))
case FuncExpr("rand", Nil) => apply("rand")
case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col)
case unsupported => throw CHClientException(s"Unsupported ClickHouse expression: $unsupported")
}
def toSparkExpression(expr: Expr, functionRegistry: FunctionRegistry): V2Expression =
expr match {
case FieldRef(col) => identity(col)
case StringLiteral(value) => literal(value)
case FuncExpr("rand", Nil) => apply("rand")
case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col)
case FuncExpr(funName, args) if functionRegistry.getFuncMappingByCk.contains(funName) =>
apply(functionRegistry.getFuncMappingByCk(funName), args.map(toSparkExpression(_, functionRegistry)): _*)
case unsupported => throw CHClientException(s"Unsupported ClickHouse expression: $unsupported")
}

def toClickHouse(transform: Transform): Expr = transform match {
case YearsTransform(FieldReference(Seq(col))) => FuncExpr("toYear", List(FieldRef(col)))
case MonthsTransform(FieldReference(Seq(col))) => FuncExpr("toYYYYMM", List(FieldRef(col)))
case DaysTransform(FieldReference(Seq(col))) => FuncExpr("toYYYYMMDD", List(FieldRef(col)))
case HoursTransform(FieldReference(Seq(col))) => FuncExpr("toHour", List(FieldRef(col)))
def toClickHouse(
transform: Transform,
functionRegistry: FunctionRegistry
): Expr = transform match {
case IdentityTransform(fieldRefs) => FieldRef(fieldRefs.describe)
case ApplyTransform(name, args) => FuncExpr(name, args.map(arg => SQLExpr(arg.describe())).toList)
case ApplyTransform(name, args) if functionRegistry.getFuncMappingBySpark.contains(name) =>
FuncExpr(functionRegistry.getFuncMappingBySpark(name), args.map(arg => SQLExpr(arg.describe())).toList)
case bucket: BucketTransform => throw CHClientException(s"Bucket transform not support yet: $bucket")
case other: Transform => throw CHClientException(s"Unsupported transform: $other")
}

def inferTransformSchema(
primarySchema: StructType,
secondarySchema: StructType,
transform: Transform
transform: Transform,
functionRegistry: FunctionRegistry
): StructField = transform match {
case years: YearsTransform => StructField(years.toString, IntegerType)
case months: MonthsTransform => StructField(months.toString, IntegerType)
case days: DaysTransform => StructField(days.toString, IntegerType)
case hours: HoursTransform => StructField(hours.toString, IntegerType)
case IdentityTransform(FieldReference(Seq(col))) => primarySchema.find(_.name == col)
.orElse(secondarySchema.find(_.name == col))
.getOrElse(throw CHClientException(s"Invalid partition column: $col"))
case ckXxhHash64 @ ApplyTransform("ck_xx_hash64", _) => StructField(ckXxhHash64.toString, LongType)
case t @ ApplyTransform(transformName, _) if functionRegistry.load(transformName).isDefined =>
val resType =
functionRegistry.load(transformName).getOrElse(throw new NoSuchFunctionException(transformName)) match {
case f: ScalarFunction[_] => f.resultType()
case other => throw CHClientException(s"Unsupported function: $other")
}
StructField(t.toString, resType)
case bucket: BucketTransform => throw CHClientException(s"Bucket transform not support yet: $bucket")
case other: Transform => throw CHClientException(s"Unsupported transform: $other")
}

def toSplitWithModulo(shardingKey: Expr, weight: Int): FuncExpr =
FuncExpr("modulo", List(shardingKey, StringLiteral(weight.toString)))
}
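For orientation, a hedged sketch of how the new registry-aware conversion is expected to behave, assuming the StaticFunctionRegistry shown further down in this diff:

// Sketch (not part of the diff): mapping a ClickHouse sharding expression to a Spark V2 transform.
import org.apache.spark.sql.clickhouse.ExprUtils
import xenon.clickhouse.expr.{FieldRef, FuncExpr}
import xenon.clickhouse.func.StaticFunctionRegistry

val ckShardingKey = FuncExpr("cityHash64", List(FieldRef("value"))) // ClickHouse: cityHash64(value)
val transform = ExprUtils.toSparkTransformOpt(ckShardingKey, StaticFunctionRegistry)
// Expected: Some(ApplyTransform) for the registered "clickhouse_cityHash64" function, since
// StaticFunctionRegistry.getFuncMappingByCk maps "cityHash64" -> "clickhouse_cityHash64".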
@@ -26,7 +26,7 @@ import xenon.clickhouse.Constants._
import xenon.clickhouse.client.NodeClient
import xenon.clickhouse.exception.CHClientException
import xenon.clickhouse.exception.ClickHouseErrCode._
import xenon.clickhouse.func.{FunctionRegistry, _}
import xenon.clickhouse.func.{ClickHouseXxHash64Shard, FunctionRegistry, _}
import xenon.clickhouse.spec._

import java.time.ZoneId
@@ -91,6 +91,7 @@ class ClickHouseCatalog extends TableCatalog

log.info(s"Detect ${clusterSpecs.size} ClickHouse clusters: ${clusterSpecs.map(_.name).mkString(",")}")
log.info(s"ClickHouse clusters' detail: $clusterSpecs")
log.info(s"functionRegistry: ${this.functionRegistry.list.mkString(",")}")
}

override def name(): String = catalogName
@@ -141,7 +142,8 @@ class ClickHouseCatalog extends TableCatalog
tableClusterSpec,
_tz,
tableSpec,
tableEngineSpec
tableEngineSpec,
functionRegistry
)
}

@@ -206,7 +208,7 @@ class ClickHouseCatalog extends TableCatalog

val partitionsClause = partitions match {
case transforms if transforms.nonEmpty =>
transforms.map(ExprUtils.toClickHouse(_).sql).mkString("PARTITION BY (", ", ", ")")
transforms.map(ExprUtils.toClickHouse(_, functionRegistry).sql).mkString("PARTITION BY (", ", ", ")")
case _ => ""
}

@@ -297,7 +299,7 @@ class ClickHouseCatalog extends TableCatalog
}
tableOpt match {
case None => false
case Some(ClickHouseTable(_, cluster, _, tableSpec, _)) =>
case Some(ClickHouseTable(_, cluster, _, tableSpec, _, _)) =>
val (db, tbl) = (tableSpec.database, tableSpec.name)
val isAtomic = loadNamespaceMetadata(Array(db)).get("engine").equalsIgnoreCase("atomic")
val syncClause = if (isAtomic) "SYNC" else ""
@@ -14,16 +14,12 @@

package xenon.clickhouse

import java.lang.{Integer => JInt, Long => JLong}
import java.time.{LocalDate, ZoneId}
import java.util
import scala.collection.JavaConverters._
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.clickhouse.{ExprUtils, ReadOptions, WriteOptions}
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.{READ_DISTRIBUTED_CONVERT_LOCAL, USE_NULLABLE_QUERY_SCHEMA}
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.clickhouse.{ExprUtils, ReadOptions, WriteOptions}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.LogicalWriteInfo
@@ -34,16 +30,23 @@ import org.apache.spark.unsafe.types.UTF8String
import xenon.clickhouse.Utils._
import xenon.clickhouse.client.NodeClient
import xenon.clickhouse.expr.{Expr, OrderExpr}
import xenon.clickhouse.func.FunctionRegistry
import xenon.clickhouse.read.{ClickHouseMetadataColumn, ClickHouseScanBuilder, ScanJobDescription}
import xenon.clickhouse.spec._
import xenon.clickhouse.write.{ClickHouseWriteBuilder, WriteJobDescription}

import java.lang.{Integer => JInt, Long => JLong}
import java.time.{LocalDate, ZoneId}
import java.util
import scala.collection.JavaConverters._

case class ClickHouseTable(
node: NodeSpec,
cluster: Option[ClusterSpec],
implicit val tz: ZoneId,
spec: TableSpec,
engineSpec: TableEngineSpec
engineSpec: TableEngineSpec,
functionRegistry: FunctionRegistry
) extends Table
with SupportsRead
with SupportsWrite
@@ -130,10 +133,12 @@ case class ClickHouseTable(
private lazy val metadataSchema: StructType =
StructType(metadataColumns.map(_.asInstanceOf[ClickHouseMetadataColumn].toStructField))

override lazy val partitioning: Array[Transform] = ExprUtils.toSparkPartitions(partitionKey)
override lazy val partitioning: Array[Transform] = ExprUtils.toSparkPartitions(partitionKey, functionRegistry)

override lazy val partitionSchema: StructType = StructType(
partitioning.map(partTransform => ExprUtils.inferTransformSchema(schema, metadataSchema, partTransform))
partitioning.map(partTransform =>
ExprUtils.inferTransformSchema(schema, metadataSchema, partTransform, functionRegistry)
)
)

override lazy val properties: util.Map[String, String] = spec.toJavaMap
@@ -170,7 +175,8 @@ case class ClickHouseTable(
shardingKey = shardingKey,
partitionKey = partitionKey,
sortingKey = sortingKey,
writeOptions = new WriteOptions(info.options.asCaseSensitiveMap())
writeOptions = new WriteOptions(info.options.asCaseSensitiveMap()),
functionRegistry = functionRegistry
)

new ClickHouseWriteBuilder(writeJob)
@@ -0,0 +1,26 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xenon.clickhouse.func

import xenon.clickhouse.hash

object CityHash64 extends MultiStringArgsHash {
// https://github.com/ClickHouse/ClickHouse/blob/v23.5.3.24-stable/src/Functions/FunctionsHashing.h#L694

override protected def funcName: String = "clickhouse_cityHash64"
override val ckFuncNames: Array[String] = Array("cityHash64")

override def applyHash(input: Array[Any]): Long = hash.CityHash64(input)
}
@@ -0,0 +1,52 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xenon.clickhouse.func

import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._

import java.time.LocalDate
import java.time.format.DateTimeFormatter

object Days extends UnboundFunction with ScalarFunction[Int] with ClickhouseEquivFunction {

override def name: String = "clickhouse_days"

override def canonicalName: String = s"clickhouse.$name"

override def toString: String = name

override val ckFuncNames: Array[String] = Array("toYYYYMMDD")

override def description: String = s"$name: (date: Date) => shard_num: int"

override def bind(inputType: StructType): BoundFunction = inputType.fields match {
case Array(StructField(_, DateType, _, _)) => this
// case Array(StructField(_, TimestampType, _, _)) | Array(StructField(_, TimestampNTZType, _, _)) => this
case _ => throw new UnsupportedOperationException(s"Expect 1 DATE argument. $description")
}

override def inputTypes: Array[DataType] = Array(DateType)

override def resultType: DataType = IntegerType

override def isResultNullable: Boolean = false

def invoke(days: Int): Int = {
val date = LocalDate.ofEpochDay(days)
val formatter = DateTimeFormatter.ofPattern("yyyyMMdd")
date.format(formatter).toInt
}
}
@@ -18,30 +18,62 @@ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction

import scala.collection.mutable

trait FunctionRegistry {
trait FunctionRegistry extends Serializable {

def list: Array[String]

def load(name: String): Option[UnboundFunction]

def getFuncMappingBySpark: Map[String, String]

def getFuncMappingByCk: Map[String, String]
}

trait ClickhouseEquivFunction {
val ckFuncNames: Array[String]
}

class CompositeFunctionRegistry(registries: Array[FunctionRegistry]) extends FunctionRegistry {

override def list: Array[String] = registries.flatMap(_.list)

override def load(name: String): Option[UnboundFunction] = registries.flatMap(_.load(name)).headOption

override def getFuncMappingBySpark: Map[String, String] = registries.flatMap(_.getFuncMappingBySpark).toMap

override def getFuncMappingByCk: Map[String, String] = registries.flatMap(_.getFuncMappingByCk).toMap
}

object StaticFunctionRegistry extends FunctionRegistry {

private val functions = Map[String, UnboundFunction](
"ck_xx_hash64" -> ClickHouseXxHash64, // for compatible
"clickhouse_xxHash64" -> ClickHouseXxHash64
"clickhouse_xxHash64" -> ClickHouseXxHash64,
"clickhouse_murmurHash2_32" -> MurmurHash2_32,
"clickhouse_murmurHash2_64" -> MurmurHash2_64,
"clickhouse_murmurHash3_32" -> MurmurHash3_32,
"clickhouse_murmurHash3_64" -> MurmurHash3_64,
"clickhouse_cityHash64" -> CityHash64,
"clickhouse_years" -> Years,
"clickhouse_months" -> Months,
"clickhouse_days" -> Days,
"clickhouse_hours" -> Hours,
"sharding_mod" -> Mod
)

override def list: Array[String] = functions.keys.toArray

override def load(name: String): Option[UnboundFunction] = functions.get(name)

override val getFuncMappingBySpark: Map[String, String] =
functions.filter(_._2.isInstanceOf[ClickhouseEquivFunction]).flatMap { case (k, v) =>
v.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.map((k, _))
}

override val getFuncMappingByCk: Map[String, String] =
functions.filter(_._2.isInstanceOf[ClickhouseEquivFunction]).flatMap { case (k, v) =>
v.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.map((_, k))
}
}

class DynamicFunctionRegistry extends FunctionRegistry {
@@ -56,4 +88,14 @@ class DynamicFunctionRegistry extends FunctionRegistry {
override def list: Array[String] = functions.keys.toArray

override def load(name: String): Option[UnboundFunction] = functions.get(name)

override def getFuncMappingBySpark: Map[String, String] =
functions.filter(_._2.isInstanceOf[ClickhouseEquivFunction]).toMap.flatMap { case (k, v) =>
v.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.map((k, _))
}

override def getFuncMappingByCk: Map[String, String] =
functions.filter(_._2.isInstanceOf[ClickhouseEquivFunction]).toMap.flatMap { case (k, v) =>
v.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.map((_, k))
}
}
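A quick sketch of the two lookup directions these mappings expose; the values follow directly from each function object's declared ckFuncNames:

// Sketch: Spark-name -> ClickHouse-name and ClickHouse-name -> Spark-name lookups.
import xenon.clickhouse.func.StaticFunctionRegistry

StaticFunctionRegistry.getFuncMappingBySpark("clickhouse_cityHash64") // "cityHash64"
StaticFunctionRegistry.getFuncMappingByCk("toYYYYMM")                 // "clickhouse_months"
StaticFunctionRegistry.getFuncMappingByCk("YEAR")                     // "clickhouse_years" (synonyms are mapped too)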
@@ -0,0 +1,51 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xenon.clickhouse.func

import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._

import java.sql.Timestamp
import java.text.SimpleDateFormat

object Hours extends UnboundFunction with ScalarFunction[Int] with ClickhouseEquivFunction {

override def name: String = "clickhouse_hours"

override def canonicalName: String = s"clickhouse.$name"

override def toString: String = name

override val ckFuncNames: Array[String] = Array("toHour", "HOUR")

override def description: String = s"$name: (time: timestamp) => shard_num: int"

override def bind(inputType: StructType): BoundFunction = inputType.fields match {
case Array(StructField(_, TimestampType, _, _)) | Array(StructField(_, TimestampNTZType, _, _)) => this
case _ => throw new UnsupportedOperationException(s"Expect 1 TIMESTAMP argument. $description")
}

override def inputTypes: Array[DataType] = Array(TimestampType)

override def resultType: DataType = IntegerType

override def isResultNullable: Boolean = false

def invoke(time: Long): Int = {
val ts = new Timestamp(time / 1000)
val formatter: SimpleDateFormat = new SimpleDateFormat("HH") // 24-hour clock (0-23), matching ClickHouse toHour
formatter.format(ts).toInt
}
}
@@ -0,0 +1,63 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xenon.clickhouse.func

import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._

object Mod extends UnboundFunction with ScalarFunction[Long] with ClickhouseEquivFunction {

override def name: String = "sharding_mod"

override def canonicalName: String = s"clickhouse.$name"

override def toString: String = name

// `remainder` is not a ClickHouse function, but `modulo` is parsed to `remainder` in the connector,
// so `remainder` is added as a synonym.
override val ckFuncNames: Array[String] = Array("modulo", "remainder")

override def description: String = s"$name: (a: long, b: long) => mod: long"

override def bind(inputType: StructType): BoundFunction = inputType.fields match {
case Array(a, b) if
(a match {
case StructField(_, LongType, _, _) => true
case StructField(_, IntegerType, _, _) => true
case StructField(_, ShortType, _, _) => true
case StructField(_, ByteType, _, _) => true
case StructField(_, StringType, _, _) => true
case _ => false
}) &&
(b match {
case StructField(_, LongType, _, _) => true
case StructField(_, IntegerType, _, _) => true
case StructField(_, ShortType, _, _) => true
case StructField(_, ByteType, _, _) => true
case StructField(_, StringType, _, _) => true
case _ => false
}) =>
this
case _ => throw new UnsupportedOperationException(s"Expect 2 integer arguments. $description")
}

override def inputTypes: Array[DataType] = Array(LongType, LongType)

override def resultType: DataType = LongType

override def isResultNullable: Boolean = false

def invoke(a: Long, b: Long): Long = a % b
}
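As the commit history notes, pmod was replaced with plain mod here; a quick sketch of the resulting sign semantics (assuming ClickHouse's modulo follows C-style truncation, i.e. the result takes the dividend's sign, as Scala's % does):

// Sketch: sharding_mod keeps the dividend's sign, unlike Spark-style pmod.
import xenon.clickhouse.func.Mod

Mod.invoke(7L, 5L)  // 2
Mod.invoke(-7L, 5L) // -2 (pmod would give 3)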
@@ -0,0 +1,52 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xenon.clickhouse.func

import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._

import java.time.LocalDate
import java.time.format.DateTimeFormatter

object Months extends UnboundFunction with ScalarFunction[Int] with ClickhouseEquivFunction {

override def name: String = "clickhouse_months"

override def canonicalName: String = s"clickhouse.$name"

override def toString: String = name

override val ckFuncNames: Array[String] = Array("toYYYYMM")

override def description: String = s"$name: (date: Date) => shard_num: int"

override def bind(inputType: StructType): BoundFunction = inputType.fields match {
case Array(StructField(_, DateType, _, _)) => this
// case Array(StructField(_, TimestampType, _, _)) | Array(StructField(_, TimestampNTZType, _, _)) => this
case _ => throw new UnsupportedOperationException(s"Expect 1 DATE argument. $description")
}

override def inputTypes: Array[DataType] = Array(DateType)

override def resultType: DataType = IntegerType

override def isResultNullable: Boolean = false

def invoke(days: Int): Int = {
val date = LocalDate.ofEpochDay(days)
val formatter = DateTimeFormatter.ofPattern("yyyyMM")
date.format(formatter).toInt
}
}
@@ -0,0 +1,59 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xenon.clickhouse.func

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

abstract class MultiStringArgsHash extends UnboundFunction with ClickhouseEquivFunction {

def applyHash(input: Array[Any]): Long

protected def funcName: String

override val ckFuncNames: Array[String]

override def description: String = s"$name: (value: string, ...) => hash_value: long"

private def isExpectedType(dt: DataType): Boolean =
dt.isInstanceOf[StringType]

final override def name: String = funcName

final override def bind(inputType: StructType): BoundFunction = {
val inputDataTypes = inputType.fields.map(_.dataType)
if (inputDataTypes.forall(isExpectedType)) {
// a new ScalarFunction instance is needed for each bind,
// because the number of arguments is not known in advance
new ScalarFunction[Long] {
override def inputTypes(): Array[DataType] = inputDataTypes
override def name: String = funcName
override def canonicalName: String = s"clickhouse.$name"
override def resultType: DataType = LongType
override def toString: String = name
override def produceResult(input: InternalRow): Long = {
val inputStrings: Array[Any] =
input.toSeq(Seq.fill(input.numFields)(StringType)).asInstanceOf[Seq[UTF8String]].toArray
.map(_.getBytes)
applyHash(inputStrings)
}
}
} else throw new UnsupportedOperationException(s"Expect multiple STRING argument. $description")

}

}
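A sketch of how the per-bind ScalarFunction is expected to be used, assuming the CityHash64 object defined earlier in this diff; the argument strings are arbitrary, and the resulting value comes from the hash implementation in clickhouse-core:

// Sketch: binding one of the variadic hash UDFs to a two-string schema and hashing a row.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import xenon.clickhouse.func.CityHash64

val bound = CityHash64
  .bind(StructType(Seq(StructField("_0", StringType), StructField("_1", StringType))))
  .asInstanceOf[ScalarFunction[Long]]

val row = InternalRow(UTF8String.fromString("Apache Spark"), UTF8String.fromString("ClickHouse"))
val hashed: Long = bound.produceResult(row) // should equal ClickHouse cityHash64('Apache Spark', 'ClickHouse')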
@@ -0,0 +1,36 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xenon.clickhouse.func

import xenon.clickhouse.hash
import xenon.clickhouse.hash.HashUtils

object MurmurHash2_64 extends MultiStringArgsHash {
// https://github.com/ClickHouse/ClickHouse/blob/v23.5.3.24-stable/src/Functions/FunctionsHashing.h#L460

override protected def funcName: String = "clickhouse_murmurHash2_64"
override val ckFuncNames: Array[String] = Array("murmurHash2_64")

override def applyHash(input: Array[Any]): Long = hash.Murmurhash2_64(input)
}

object MurmurHash2_32 extends MultiStringArgsHash {
// https://github.com/ClickHouse/ClickHouse/blob/v23.5.3.24-stable/src/Functions/FunctionsHashing.h#L519

override protected def funcName: String = "clickhouse_murmurHash2_32"
override val ckFuncNames: Array[String] = Array("murmurHash2_32")

override def applyHash(input: Array[Any]): Long = HashUtils.toUInt32(hash.Murmurhash2_32(input))
}
@@ -0,0 +1,36 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xenon.clickhouse.func

import xenon.clickhouse.hash
import xenon.clickhouse.hash.HashUtils

object MurmurHash3_64 extends MultiStringArgsHash {
// https://github.com/ClickHouse/ClickHouse/blob/v23.5.3.24-stable/src/Functions/FunctionsHashing.h#L543

override protected def funcName: String = "clickhouse_murmurHash3_64"
override val ckFuncNames: Array[String] = Array("murmurHash3_64")

override def applyHash(input: Array[Any]): Long = hash.Murmurhash3_64(input)
}

object MurmurHash3_32 extends MultiStringArgsHash {
// https://github.com/ClickHouse/ClickHouse/blob/v23.5.3.24-stable/src/Functions/FunctionsHashing.h#L519

override protected def funcName: String = "clickhouse_murmurHash3_32"
override val ckFuncNames: Array[String] = Array("murmurHash3_32")

override def applyHash(input: Array[Any]): Long = HashUtils.toUInt32(hash.Murmurhash3_32(input))
}
@@ -0,0 +1,52 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xenon.clickhouse.func

object Util {
def intHash64Impl(x: Long): Long =
// https://github.com/ClickHouse/ClickHouse/blob/v23.5.3.24-stable/src/Functions/FunctionsHashing.h#L140
intHash64(x ^ 0x4cf2d2baae6da887L)

def intHash64(l: Long): Long = {
// https://github.com/ClickHouse/ClickHouse/blob/v23.5.3.24-stable/src/Common/HashTable/Hash.h#L26
var x = l
x ^= x >>> 33;
x *= 0xff51afd7ed558ccdL;
x ^= x >>> 33;
x *= 0xc4ceb9fe1a85ec53L;
x ^= x >>> 33;
x
}

def int32Impl(x: Long): Int =
// https://github.com/ClickHouse/ClickHouse/blob/v23.5.3.24-stable/src/Functions/FunctionsHashing.h#L133
intHash32(x, 0x75d9543de018bf45L)

def intHash32(l: Long, salt: Long): Int = {
// https://github.com/ClickHouse/ClickHouse/blob/v23.5.3.24-stable/src/Common/HashTable/Hash.h#L502
var x = l

x ^= salt;
x = (~x) + (x << 18)
x = x ^ ((x >>> 31) | (x << 33))
x = x * 21
x = x ^ ((x >>> 11) | (x << 53))
x = x + (x << 6)
x = x ^ ((x >>> 22) | (x << 42))
x.toInt
}

def toUInt32Range(v: Long): Long = if (v < 0) v + (1L << 32) else v
}
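These helpers are ports of the ClickHouse hashing primitives linked in the comments; a tiny worked example of the unsigned-range conversion:

// Worked example of Util.toUInt32Range, which maps a negative signed Long back into UInt32 range.
import xenon.clickhouse.func.Util

Util.toUInt32Range(-1L)     // 4294967295L, i.e. (1L << 32) - 1
Util.toUInt32Range(123456L) // 123456L (non-negative values pass through unchanged)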
@@ -26,12 +26,16 @@ import xenon.clickhouse.spec.{ClusterSpec, ShardUtils}
* select xxHash64(concat(project_id, toString(seq)))
* }}}
*/
object ClickHouseXxHash64 extends UnboundFunction with ScalarFunction[Long] {
object ClickHouseXxHash64 extends UnboundFunction with ScalarFunction[Long] with ClickhouseEquivFunction {

override def name: String = "clickhouse_xxHash64"

override def canonicalName: String = s"clickhouse.$name"

override def toString: String = name

override val ckFuncNames: Array[String] = Array("xxHash64")

override def description: String = s"$name: (value: string) => hash_value: long"

override def bind(inputType: StructType): BoundFunction = inputType.fields match {
@@ -45,6 +49,7 @@ object ClickHouseXxHash64 extends UnboundFunction with ScalarFunction[Long] {

override def isResultNullable: Boolean = false

// ignore UInt64 vs Int64
def invoke(value: UTF8String): Long = XxHash64Function.hash(value, StringType, 0L)
}

@@ -0,0 +1,52 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xenon.clickhouse.func

import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._

import java.time.LocalDate
import java.time.format.DateTimeFormatter

object Years extends UnboundFunction with ScalarFunction[Int] with ClickhouseEquivFunction {

override def name: String = "clickhouse_years"

override def canonicalName: String = s"clickhouse.$name"

override def toString: String = name

override val ckFuncNames: Array[String] = Array("toYear", "YEAR")

override def description: String = s"$name: (date: Date) => shard_num: int"

override def bind(inputType: StructType): BoundFunction = inputType.fields match {
case Array(StructField(_, DateType, _, _)) => this
// case Array(StructField(_, TimestampType, _, _)) | Array(StructField(_, TimestampNTZType, _, _)) => this
case _ => throw new UnsupportedOperationException(s"Expect 1 DATE argument. $description")
}

override def inputTypes: Array[DataType] = Array(DateType)

override def resultType: DataType = IntegerType

override def isResultNullable: Boolean = false

def invoke(days: Int): Int = {
val date = LocalDate.ofEpochDay(days)
val formatter = DateTimeFormatter.ofPattern("yyyy")
date.format(formatter).toInt
}
}
@@ -17,7 +17,7 @@ package xenon.clickhouse.write
import com.clickhouse.client.ClickHouseProtocol
import com.clickhouse.data.ClickHouseCompression
import org.apache.commons.io.IOUtils
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, SafeProjection}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, SafeProjection, TransformExpression}
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.clickhouse.ExprUtils
import org.apache.spark.sql.connector.metric.CustomTaskMetric
@@ -56,7 +56,7 @@ abstract class ClickHouseWriter(writeJob: WriteJobDescription)
protected lazy val shardExpr: Option[Expression] = writeJob.sparkShardExpr match {
case None => None
case Some(v2Expr) =>
val catalystExpr = ExprUtils.toCatalyst(v2Expr, writeJob.dataSetSchema.fields)
val catalystExpr = ExprUtils.toCatalyst(v2Expr, writeJob.dataSetSchema.fields, writeJob.functionRegistry)
catalystExpr match {
case BoundReference(_, dataType, _)
if dataType.isInstanceOf[ByteType] // list all integral types here because we can not access `IntegralType`
@@ -66,6 +66,11 @@ abstract class ClickHouseWriter(writeJob: WriteJobDescription)
Some(catalystExpr)
case BoundReference(_, dataType, _) =>
throw CHClientException(s"Invalid data type of sharding field: $dataType")
case TransformExpression(function, _, _) =>
function.resultType() match {
case ByteType | ShortType | IntegerType | LongType => Some(catalystExpr)
case _ => throw CHClientException(s"Invalid data type of sharding field: ${function.resultType()}")
}
case unsupported: Expression =>
log.warn(s"Unsupported expression of sharding field: $unsupported")
None
@@ -74,7 +79,21 @@ abstract class ClickHouseWriter(writeJob: WriteJobDescription)

protected lazy val shardProjection: Option[expressions.Projection] = shardExpr
.filter(_ => writeJob.writeOptions.convertDistributedToLocal)
.map(expr => SafeProjection.create(Seq(expr)))
.flatMap {
case expr: BoundReference =>
Some(SafeProjection.create(Seq(expr)))
case expr @ TransformExpression(function, _, _) =>
// result type must be integer class
function.resultType() match {
case ByteType => classOf[Byte]
case ShortType => classOf[Short]
case IntegerType => classOf[Int]
case LongType => classOf[Long]
case _ => throw CHClientException(s"Invalid return data type for function ${function.name()}," +
s"sharding field: ${function.resultType()}")
}
Some(SafeProjection.create(Seq(ExprUtils.resolveTransformCatalyst(expr, Some(writeJob.tz.getId)))))
}

// put the node select strategy in executor side because we need to calculate shard and don't know the records
// util DataWriter#write(InternalRow) invoked.
@@ -107,6 +126,15 @@ abstract class ClickHouseWriter(writeJob: WriteJobDescription)
case _ => None
}
shardValue.map(value => ShardUtils.calcShard(writeJob.cluster.get, value).num)
case (Some(TransformExpression(function, _, _)), Some(projection)) =>
val shardValue = function.resultType() match {
case ByteType => Some(projection(record).getByte(0).toLong)
case ShortType => Some(projection(record).getShort(0).toLong)
case IntegerType => Some(projection(record).getInt(0).toLong)
case LongType => Some(projection(record).getLong(0))
case _ => None
}
shardValue.map(value => ShardUtils.calcShard(writeJob.cluster.get, value).num)
case _ => None
}

@@ -15,11 +15,11 @@
package xenon.clickhouse.write

import java.time.ZoneId

import org.apache.spark.sql.clickhouse.{ExprUtils, WriteOptions}
import org.apache.spark.sql.connector.expressions.{Expression, SortOrder, Transform}
import org.apache.spark.sql.types.StructType
import xenon.clickhouse.expr.{Expr, FuncExpr, OrderExpr}
import xenon.clickhouse.func.FunctionRegistry
import xenon.clickhouse.spec._

case class WriteJobDescription(
@@ -37,7 +37,8 @@ case class WriteJobDescription(
shardingKey: Option[Expr],
partitionKey: Option[List[Expr]],
sortingKey: Option[List[OrderExpr]],
writeOptions: WriteOptions
writeOptions: WriteOptions,
functionRegistry: FunctionRegistry
) {

def targetDatabase(convert2Local: Boolean): String = tableEngineSpec match {
@@ -56,20 +57,34 @@ case class WriteJobDescription(
}

def sparkShardExpr: Option[Expression] = shardingKeyIgnoreRand match {
case Some(expr) => ExprUtils.toSparkTransformOpt(expr)
case Some(expr) => ExprUtils.toSparkTransformOpt(expr, functionRegistry)
case _ => None
}

def sparkSplits: Array[Transform] =
// Take the sharding key modulo (total weight * constant). Note that this key will be further hashed by Spark.
// Reasons for doing this:
// - Enlarging the modulo range avoids hash collisions when the number of shards is small, which mitigates the
// data skew those collisions would cause.
// - Data from one shard is still distributed to only a subset of executors. If we did not apply the modulo here
// (and instead applied it only during sorting in `toSparkSortOrders`), data belonging to shard 1 would be sorted
// to the front in every task, putting instant high pressure on shard 1 when the stage starts.
if (writeOptions.repartitionByPartition) {
ExprUtils.toSparkSplits(shardingKeyIgnoreRand, partitionKey)
ExprUtils.toSparkSplits(
shardingKeyIgnoreRand.map(k => ExprUtils.toSplitWithModulo(k, cluster.get.totalWeight * 5)),
partitionKey,
functionRegistry
)
} else {
ExprUtils.toSparkSplits(shardingKeyIgnoreRand, None)
ExprUtils.toSparkSplits(
shardingKeyIgnoreRand.map(k => ExprUtils.toSplitWithModulo(k, cluster.get.totalWeight * 5)),
None,
functionRegistry
)
}

def sparkSortOrders: Array[SortOrder] = {
val _partitionKey = if (writeOptions.localSortByPartition) partitionKey else None
val _sortingKey = if (writeOptions.localSortByKey) sortingKey else None
ExprUtils.toSparkSortOrders(shardingKeyIgnoreRand, _partitionKey, _sortingKey)
ExprUtils.toSparkSortOrders(shardingKeyIgnoreRand, _partitionKey, _sortingKey, cluster, functionRegistry)
}
}
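A sketch of the split key described by the comment above, assuming a cluster whose shard weights sum to 3 (an illustrative value):

// Sketch only: the modulo split built for repartitioning, given total weight 3 and the constant 5.
import org.apache.spark.sql.clickhouse.ExprUtils
import xenon.clickhouse.expr.{FieldRef, FuncExpr, StringLiteral}

val shardingKey = FuncExpr("cityHash64", List(FieldRef("value")))
val split = ExprUtils.toSplitWithModulo(shardingKey, 3 * 5)
// split == FuncExpr("modulo", List(shardingKey, StringLiteral("15"))),
// so Spark repartitions by modulo(cityHash64(value), 15) and then hashes that key once more itself.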
@@ -0,0 +1,65 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.clickhouse

import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.scalatest.funsuite.AnyFunSuite
import xenon.clickhouse.ClickHouseHelper
import xenon.clickhouse.func.{
ClickHouseXxHash64,
ClickhouseEquivFunction,
CompositeFunctionRegistry,
DynamicFunctionRegistry,
StaticFunctionRegistry
}

import scala.collection.JavaConverters._

class FunctionRegistrySuite extends AnyFunSuite {

val staticFunctionRegistry: StaticFunctionRegistry.type = StaticFunctionRegistry
val dynamicFunctionRegistry = new DynamicFunctionRegistry
dynamicFunctionRegistry.register("ck_xx_hash64", ClickHouseXxHash64)
dynamicFunctionRegistry.register("clickhouse_xxHash64", ClickHouseXxHash64)

test("check StaticFunctionRegistry mappings") {
assert(staticFunctionRegistry.getFuncMappingBySpark.forall { case (k, v) =>
staticFunctionRegistry.load(k).get.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.contains(v)
})
assert(staticFunctionRegistry.getFuncMappingByCk.forall { case (k, v) =>
staticFunctionRegistry.load(v).get.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.contains(k)
})
}

test("check DynamicFunctionRegistry mappings") {
assert(dynamicFunctionRegistry.getFuncMappingBySpark.forall { case (k, v) =>
dynamicFunctionRegistry.load(k).get.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.contains(v)
})
assert(dynamicFunctionRegistry.getFuncMappingByCk.forall { case (k, v) =>
dynamicFunctionRegistry.load(v).get.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.contains(k)
})
}

test("check CompositeFunctionRegistry mappings") {
val compositeFunctionRegistry =
new CompositeFunctionRegistry(Array(staticFunctionRegistry, dynamicFunctionRegistry))
assert(compositeFunctionRegistry.getFuncMappingBySpark.forall { case (k, v) =>
compositeFunctionRegistry.load(k).get.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.contains(v)
})
assert(compositeFunctionRegistry.getFuncMappingByCk.forall { case (k, v) =>
compositeFunctionRegistry.load(v).get.asInstanceOf[ClickhouseEquivFunction].ckFuncNames.contains(k)
})
}
}