Skip to content

Commit

Permalink
feat: introduce substrait-spark module (#271)
Browse files Browse the repository at this point in the history
Provides code to assist with translating Substrait to and from Spark

This module was originally part of the Gluten project, but it was removed from there.

For more details see:
apache/incubator-gluten#4609
  • Loading branch information
andrew-coleman authored Jun 26, 2024
1 parent 8a8253e commit 8537dca
Show file tree
Hide file tree
Showing 36 changed files with 3,597 additions and 1 deletion.
2 changes: 2 additions & 0 deletions gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ jackson.version=2.16.1
junit.version=5.8.1
protobuf.version=3.25.3
slf4j.version=2.0.13
sparkbundle.version=3.4
spark.version=3.4.2

#version that is going to be updated automatically by releases
version = 0.34.0
Expand Down
2 changes: 1 addition & 1 deletion settings.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
rootProject.name = "substrait"

include("bom", "core", "isthmus", "isthmus-cli")
include("bom", "core", "isthmus", "isthmus-cli", "spark")

pluginManagement {
plugins {
Expand Down
113 changes: 113 additions & 0 deletions spark/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Build plugins for the substrait-spark module: Java + Scala compilation,
// IDE metadata, code formatting (spotless), and Maven publishing/signing.
plugins {
`maven-publish`
id("java")
id("scala")
id("idea")
id("com.diffplug.spotless") version "6.11.0"
signing
}

// Maven publication metadata for the substrait-spark artifact.
// Mirrors the POM configuration used by the other modules in this repository.
publishing {
publications {
create<MavenPublication>("maven-publish") {
// Publish the compiled Java/Scala component (jar + javadoc/sources, see java {} block).
from(components["java"])

pom {
name.set("Substrait Java")
description.set(
"Create a well-defined, cross-language specification for data compute operations"
)
url.set("https://github.com/substrait-io/substrait-java")
licenses {
license {
name.set("The Apache License, Version 2.0")
url.set("http://www.apache.org/licenses/LICENSE-2.0.txt")
}
}
developers {
developer {
// TBD Get the list of
}
}
scm {
connection.set("scm:git:git://github.com:substrait-io/substrait-java.git")
developerConnection.set("scm:git:ssh://github.com:substrait-io/substrait-java")
url.set("https://github.com/substrait-io/substrait-java/")
}
}
}
}
repositories {
maven {
// Local staging repository under build/; snapshots and releases are kept apart
// based on whether the project version string ends in "SNAPSHOT".
name = "local"
val releasesRepoUrl = layout.buildDirectory.dir("repos/releases")
val snapshotsRepoUrl = layout.buildDirectory.dir("repos/snapshots")
url = uri(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl)
}
}
}

// PGP signing of the publication. Only required when actually publishing to
// Sonatype (the setRequired predicate checks the task graph), so local builds
// do not need signing credentials.
// Each credential is taken from the environment first, falling back to a
// Gradle `extra` property of the same name.
// NOTE(review): if neither the env var nor the extra property is set, the
// `extra[...]` lookup throws — presumably acceptable since this path is only
// exercised on release machines; confirm against the CI setup.
signing {
setRequired({ gradle.taskGraph.hasTask("publishToSonatype") })
val signingKeyId =
System.getenv("SIGNING_KEY_ID").takeUnless { it.isNullOrEmpty() }
?: extra["SIGNING_KEY_ID"].toString()
val signingPassword =
System.getenv("SIGNING_PASSWORD").takeUnless { it.isNullOrEmpty() }
?: extra["SIGNING_PASSWORD"].toString()
val signingKey =
System.getenv("SIGNING_KEY").takeUnless { it.isNullOrEmpty() }
?: extra["SIGNING_KEY"].toString()
useInMemoryPgpKeys(signingKeyId, signingKey, signingPassword)
sign(publishing.publications["maven-publish"])
}

// Detach the Gradle Scala plugin's internal incremental-analysis configurations
// from their parents.
// NOTE(review): this looks like a workaround for publication/variant resolution
// problems caused by the incrementalScalaAnalysis configurations — confirm the
// original motivation (it was carried over from the Gluten project).
configurations.all {
if (name.startsWith("incrementalScalaAnalysis")) {
setExtendsFrom(emptyList())
}
}

// Compile and run with a Java 17 toolchain; also publish javadoc and sources
// jars alongside the main artifact (required for Maven Central).
java {
toolchain { languageVersion.set(JavaLanguageVersion.of(17)) }
withJavadocJar()
withSourcesJar()
}

// Target JVM 17 bytecode from the Scala compiler via `-release:17`.
// NOTE(review): targetCompatibility is cleared on purpose — presumably so the
// Scala plugin does not also inject a conflicting `-target` flag when
// `-release` is used; confirm against the Gradle Scala plugin docs before
// changing.
tasks.withType<ScalaCompile>() {
targetCompatibility = ""
scalaCompileOptions.additionalParameters = listOf("-release:17")
}

// Dependency version strings resolved from gradle.properties at the project root.
// Declared as `val` (not `var`): they are initialized once and never reassigned,
// so read-only bindings are the idiomatic Kotlin choice. Names are kept as-is
// because the sourceSets and dependencies blocks below reference them.
val SLF4J_VERSION = properties.get("slf4j.version")
val SPARKBUNDLE_VERSION = properties.get("sparkbundle.version")
val SPARK_VERSION = properties.get("spark.version")

// Main sources combine the version-agnostic Scala tree with a Spark-bundle-
// specific tree (src/main/spark-3.4 for the current sparkbundle.version).
// NOTE(review): the test set hard-codes `src/test/spark-3.2` while
// spark.version is 3.4.2 — verify this directory name is intentional and not
// a leftover from the Gluten import. Including `src/main/scala` in the test
// srcDirs recompiles main sources for tests; confirm this is deliberate.
sourceSets {
main { scala { setSrcDirs(listOf("src/main/scala", "src/main/spark-${SPARKBUNDLE_VERSION}")) } }
test { scala { setSrcDirs(listOf("src/test/scala", "src/test/spark-3.2", "src/main/scala")) } }
}

// Compile against substrait-core plus the Spark 2.12-scala-binary artifacts;
// tests run ScalaTest suites on the JUnit Platform and reuse Spark's published
// test-classified jars for SparkSession fixtures.
dependencies {
implementation(project(":core"))
implementation("org.scala-lang:scala-library:2.12.16")
implementation("org.apache.spark:spark-core_2.12:${SPARK_VERSION}")
implementation("org.apache.spark:spark-sql_2.12:${SPARK_VERSION}")
implementation("org.apache.spark:spark-catalyst_2.12:${SPARK_VERSION}")
implementation("org.slf4j:slf4j-api:${SLF4J_VERSION}")

testImplementation("org.scalatest:scalatest_2.12:3.2.18")
testRuntimeOnly("org.junit.platform:junit-platform-engine:1.10.0")
testRuntimeOnly("org.junit.platform:junit-platform-launcher:1.10.0")
// Bridges ScalaTest onto the JUnit Platform so `useJUnitPlatform` can discover it.
testRuntimeOnly("org.scalatestplus:junit-5-10_2.12:3.2.18.0")
// Spark's `tests` classifier jars provide shared test utilities/fixtures.
testImplementation("org.apache.spark:spark-core_2.12:${SPARK_VERSION}:tests")
testImplementation("org.apache.spark:spark-sql_2.12:${SPARK_VERSION}:tests")
testImplementation("org.apache.spark:spark-catalyst_2.12:${SPARK_VERSION}:tests")
}

// Test task configuration: the shaded :core jar must exist before tests run,
// and only the ScalaTest engine is used for discovery on the JUnit Platform.
tasks {
test {
dependsOn(":core:shadowJar")
useJUnitPlatform { includeEngines("scalatest") }
}
}
34 changes: 34 additions & 0 deletions spark/src/main/resources/spark.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
%YAML 1.2
---
# Substrait extension declarations for Spark-specific scalar functions used by
# this module. Each entry names a function, its argument types, and its return
# type, following the Substrait extension YAML format.
scalar_functions:
-
name: year
description: Returns the year component of the date/timestamp
impls:
- args:
- value: date
return: i32
-
name: unscaled
description: >-
Return the unscaled Long value of a Decimal, assuming it fits in a Long.
Note: this expression is internal and created only by the optimizer,
we don't need to do type check for it.
impls:
- args:
- value: DECIMAL<P,S>
return: i64
75 changes: 75 additions & 0 deletions spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.substrait.debug

import io.substrait.spark.DefaultExpressionVisitor

import org.apache.spark.sql.catalyst.util.DateTimeUtils

import io.substrait.expression.{Expression, FieldReference}
import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, StrLiteral}
import io.substrait.function.ToTypeString
import io.substrait.util.DecimalUtil

import scala.collection.JavaConverters.asScalaBufferConverter

/**
 * Debug visitor that renders Substrait [[Expression]] nodes as short,
 * human-readable strings. Literals are printed as their values, field
 * references as `$<index>`, and scalar function invocations as
 * `name[returnType](arg0,arg1,...)`. Anything else falls through to the
 * defaults inherited from [[DefaultExpressionVisitor]].
 */
class ExpressionToString extends DefaultExpressionVisitor[String] {

  override def visit(expr: DecimalLiteral): String =
    // Decode the little-endian byte payload into a BigDecimal at the literal's scale.
    DecimalUtil.getBigDecimalFromBytes(expr.value.toByteArray, expr.scale, 16).toString

  override def visit(expr: StrLiteral): String = expr.value()

  override def visit(expr: I32Literal): String = String.valueOf(expr.value())

  override def visit(expr: DateLiteral): String =
    // Literal value is a day count; convert to java.sql.Date for display.
    DateTimeUtils.toJavaDate(expr.value()).toString

  override def visit(expr: FieldReference): String =
    // Render struct-field references as "$<ordinal>".
    withFieldReference(expr)(index => s"$$$index")

  override def visit(expr: Expression.SingleOrList): String = expr.toString

  override def visit(expr: Expression.ScalarFunctionInvocation): String = {
    // Render each argument through this visitor, then assemble
    // "name[returnType](arg0,arg1,...)".
    val renderedArgs = expr
      .arguments()
      .asScala
      .zipWithIndex
      .map { case (arg, idx) => arg.accept(expr.declaration(), idx, this) }
      .mkString(",")
    val returnType = expr.outputType().accept(ToTypeString.INSTANCE)
    s"${expr.declaration().key()}[$returnType]($renderedArgs)"
  }

  override def visit(expr: Expression.UserDefinedLiteral): String = expr.toString
}
Loading

0 comments on commit 8537dca

Please sign in to comment.