From afb5d6f23abfae8950068f3cf5460ca3913a9742 Mon Sep 17 00:00:00 2001
From: yangjie01
Date: Mon, 25 Nov 2024 21:40:07 -0800
Subject: [PATCH] [SPARK-50410][CONNECT] Refactor the `sql` function in
 `SparkSession` to eliminate duplicate code

### What changes were proposed in this pull request?

There is duplicate code between
https://github.com/apache/spark/blob/7b4922ea90d19d7e0510a205740b4c150057e988/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala#L214-L235
and
https://github.com/apache/spark/blob/7b4922ea90d19d7e0510a205740b4c150057e988/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala#L245-L267

This PR introduces a new `private` function named `sql` in `SparkSession` that takes a `proto.SqlCommand` as its input, and refactors the two public overloads above to call it, removing the duplicate code.
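Because both public overloads keep their signatures, callers are unaffected. As a minimal usage sketch (illustrative only, not part of this patch; it assumes an already-created Spark Connect `SparkSession` named `spark`):

```scala
import scala.jdk.CollectionConverters._

// Positional parameters use the Array[_] overload refactored here.
val positional = spark.sql("SELECT * FROM range(10) WHERE id > ?", Array(5))

// Named parameters use the java.util.Map overload refactored here.
val named = spark.sql(
  "SELECT * FROM range(10) WHERE id > :minId",
  Map[String, Any]("minId" -> 5).asJava)
```

After this change, both calls build a `proto.SqlCommand` and delegate to the shared private helper shown in the diff below.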
### Why are the changes needed?

To eliminate duplicate code.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Pass GitHub Actions

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #48955 from LuciferYang/ref-sql-function.

Authored-by: yangjie01
Signed-off-by: Dongjoon Hyun
---
 .../org/apache/spark/sql/SparkSession.scala   | 75 ++++++++-----------
 1 file changed, 32 insertions(+), 43 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 7edb1f51f11b1..231c604b98bb5 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -211,27 +211,13 @@ class SparkSession private[sql] (
 
   /** @inheritdoc */
   @Experimental
-  def sql(sqlText: String, args: Array[_]): DataFrame = newDataFrame { builder =>
-    // Send the SQL once to the server and then check the output.
-    val cmd = newCommand(b =>
-      b.setSqlCommand(
-        proto.SqlCommand
-          .newBuilder()
-          .setSql(sqlText)
-          .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava)))
-    val plan = proto.Plan.newBuilder().setCommand(cmd)
-    val responseIter = client.execute(plan.build())
-
-    try {
-      val response = responseIter
-        .find(_.hasSqlCommandResult)
-        .getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
-      // Update the builder with the values from the result.
-      builder.mergeFrom(response.getSqlCommandResult.getRelation)
-    } finally {
-      // consume the rest of the iterator
-      responseIter.foreach(_ => ())
-    }
+  def sql(sqlText: String, args: Array[_]): DataFrame = {
+    val sqlCommand = proto.SqlCommand
+      .newBuilder()
+      .setSql(sqlText)
+      .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava)
+      .build()
+    sql(sqlCommand)
   }
 
   /** @inheritdoc */
@@ -242,28 +228,13 @@ class SparkSession private[sql] (
 
   /** @inheritdoc */
   @Experimental
-  override def sql(sqlText: String, args: java.util.Map[String, Any]): DataFrame = newDataFrame {
-    builder =>
-      // Send the SQL once to the server and then check the output.
-      val cmd = newCommand(b =>
-        b.setSqlCommand(
-          proto.SqlCommand
-            .newBuilder()
-            .setSql(sqlText)
-            .putAllNamedArguments(args.asScala.map { case (k, v) => (k, lit(v).expr) }.asJava)))
-      val plan = proto.Plan.newBuilder().setCommand(cmd)
-      val responseIter = client.execute(plan.build())
-
-      try {
-        val response = responseIter
-          .find(_.hasSqlCommandResult)
-          .getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
-        // Update the builder with the values from the result.
-        builder.mergeFrom(response.getSqlCommandResult.getRelation)
-      } finally {
-        // consume the rest of the iterator
-        responseIter.foreach(_ => ())
-      }
+  override def sql(sqlText: String, args: java.util.Map[String, Any]): DataFrame = {
+    val sqlCommand = proto.SqlCommand
+      .newBuilder()
+      .setSql(sqlText)
+      .putAllNamedArguments(args.asScala.map { case (k, v) => (k, lit(v).expr) }.asJava)
+      .build()
+    sql(sqlCommand)
   }
 
   /** @inheritdoc */
@@ -271,6 +242,24 @@ class SparkSession private[sql] (
     sql(query, Array.empty)
   }
 
+  private def sql(sqlCommand: proto.SqlCommand): DataFrame = newDataFrame { builder =>
+    // Send the SQL once to the server and then check the output.
+    val cmd = newCommand(b => b.setSqlCommand(sqlCommand))
+    val plan = proto.Plan.newBuilder().setCommand(cmd)
+    val responseIter = client.execute(plan.build())
+
+    try {
+      val response = responseIter
+        .find(_.hasSqlCommandResult)
+        .getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
+      // Update the builder with the values from the result.
+      builder.mergeFrom(response.getSqlCommandResult.getRelation)
+    } finally {
+      // consume the rest of the iterator
+      responseIter.foreach(_ => ())
+    }
+  }
+
   /** @inheritdoc */
   def read: DataFrameReader = new DataFrameReader(this)