Skip to content

Commit

Permalink
[SPARK-50410][CONNECT] Refactor the sql function in SparkSession
Browse files Browse the repository at this point in the history
…to eliminate duplicate code

### What changes were proposed in this pull request?

There is duplicate code between

https://github.com/apache/spark/blob/7b4922ea90d19d7e0510a205740b4c150057e988/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala#L214-L235

and

https://github.com/apache/spark/blob/7b4922ea90d19d7e0510a205740b4c150057e988/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala#L245-L267

This PR therefore introduces a new `private` function named `sql` in `SparkSession`. The new function takes a `proto.SqlCommand` as its input, and the two functions above are refactored to call it, which removes the duplicated code.

### Why are the changes needed?
To eliminate duplicate code.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Pass GitHub Actions

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#48955 from LuciferYang/ref-sql-function.

Authored-by: yangjie01 <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
LuciferYang authored and dongjoon-hyun committed Nov 26, 2024
1 parent 69d433b commit afb5d6f
Showing 1 changed file with 32 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -211,27 +211,13 @@ class SparkSession private[sql] (

/** @inheritdoc */
@Experimental
def sql(sqlText: String, args: Array[_]): DataFrame = newDataFrame { builder =>
// Send the SQL once to the server and then check the output.
val cmd = newCommand(b =>
b.setSqlCommand(
proto.SqlCommand
.newBuilder()
.setSql(sqlText)
.addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava)))
val plan = proto.Plan.newBuilder().setCommand(cmd)
val responseIter = client.execute(plan.build())

try {
val response = responseIter
.find(_.hasSqlCommandResult)
.getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
// Update the builder with the values from the result.
builder.mergeFrom(response.getSqlCommandResult.getRelation)
} finally {
// consume the rest of the iterator
responseIter.foreach(_ => ())
}
def sql(sqlText: String, args: Array[_]): DataFrame = {
  // Start the command with the SQL text, then attach the positional arguments.
  val commandBuilder = proto.SqlCommand.newBuilder().setSql(sqlText)
  // Each positional argument is wrapped as a literal and passed as its expression.
  val positionalArgs = args.map(lit(_).expr).toImmutableArraySeq.asJava
  // Execution is delegated to the overload that takes a proto.SqlCommand.
  sql(commandBuilder.addAllPosArguments(positionalArgs).build())
}

/** @inheritdoc */
Expand All @@ -242,35 +228,38 @@ class SparkSession private[sql] (

/** @inheritdoc */
@Experimental
override def sql(sqlText: String, args: java.util.Map[String, Any]): DataFrame = newDataFrame {
builder =>
// Send the SQL once to the server and then check the output.
val cmd = newCommand(b =>
b.setSqlCommand(
proto.SqlCommand
.newBuilder()
.setSql(sqlText)
.putAllNamedArguments(args.asScala.map { case (k, v) => (k, lit(v).expr) }.asJava)))
val plan = proto.Plan.newBuilder().setCommand(cmd)
val responseIter = client.execute(plan.build())

try {
val response = responseIter
.find(_.hasSqlCommandResult)
.getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
// Update the builder with the values from the result.
builder.mergeFrom(response.getSqlCommandResult.getRelation)
} finally {
// consume the rest of the iterator
responseIter.foreach(_ => ())
}
override def sql(sqlText: String, args: java.util.Map[String, Any]): DataFrame = {
  // Start the command with the SQL text, then attach the named arguments.
  val commandBuilder = proto.SqlCommand.newBuilder().setSql(sqlText)
  // Each named argument value is wrapped as a literal and passed as its expression.
  val namedArgs = args.asScala.map { case (name, value) => (name, lit(value).expr) }.asJava
  // Execution is delegated to the overload that takes a proto.SqlCommand.
  sql(commandBuilder.putAllNamedArguments(namedArgs).build())
}

/** @inheritdoc */
override def sql(query: String): DataFrame =
  // A query without parameters is run through the array overload with no arguments.
  sql(query, Array.empty)

/**
 * Executes the given SQL command through the client and builds a DataFrame from the relation
 * carried by the first response that has a SQL command result.
 *
 * @param sqlCommand
 *   the fully built SQL command to execute
 * @throws RuntimeException
 *   if no response contains a SQL command result
 */
private def sql(sqlCommand: proto.SqlCommand): DataFrame = newDataFrame { builder =>
  // The SQL is sent once; the responses are then inspected for the command result.
  val command = newCommand(cmdBuilder => cmdBuilder.setSqlCommand(sqlCommand))
  val responses = client.execute(proto.Plan.newBuilder().setCommand(command).build())

  try {
    responses.find(_.hasSqlCommandResult) match {
      case Some(response) =>
        // Copy the relation from the result into the DataFrame's relation builder.
        builder.mergeFrom(response.getSqlCommandResult.getRelation)
      case None =>
        throw new RuntimeException("SQLCommandResult must be present")
    }
  } finally {
    // Drain whatever is left of the response iterator, on success and on failure.
    responses.foreach(_ => ())
  }
}

/** @inheritdoc */
// Every access constructs a new DataFrameReader that is handed this session.
def read: DataFrameReader = new DataFrameReader(this)

Expand Down

0 comments on commit afb5d6f

Please sign in to comment.