diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 7edb1f51f11b1..231c604b98bb5 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -211,27 +211,13 @@ class SparkSession private[sql] ( /** @inheritdoc */ @Experimental - def sql(sqlText: String, args: Array[_]): DataFrame = newDataFrame { builder => - // Send the SQL once to the server and then check the output. - val cmd = newCommand(b => - b.setSqlCommand( - proto.SqlCommand - .newBuilder() - .setSql(sqlText) - .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava))) - val plan = proto.Plan.newBuilder().setCommand(cmd) - val responseIter = client.execute(plan.build()) - - try { - val response = responseIter - .find(_.hasSqlCommandResult) - .getOrElse(throw new RuntimeException("SQLCommandResult must be present")) - // Update the builder with the values from the result. - builder.mergeFrom(response.getSqlCommandResult.getRelation) - } finally { - // consume the rest of the iterator - responseIter.foreach(_ => ()) - } + def sql(sqlText: String, args: Array[_]): DataFrame = { + val sqlCommand = proto.SqlCommand + .newBuilder() + .setSql(sqlText) + .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava) + .build() + sql(sqlCommand) } /** @inheritdoc */ @@ -242,28 +228,13 @@ class SparkSession private[sql] ( /** @inheritdoc */ @Experimental - override def sql(sqlText: String, args: java.util.Map[String, Any]): DataFrame = newDataFrame { - builder => - // Send the SQL once to the server and then check the output. - val cmd = newCommand(b => - b.setSqlCommand( - proto.SqlCommand - .newBuilder() - .setSql(sqlText) - .putAllNamedArguments(args.asScala.map { case (k, v) => (k, lit(v).expr) }.asJava))) - val plan = proto.Plan.newBuilder().setCommand(cmd) - val responseIter = client.execute(plan.build()) - - try { - val response = responseIter - .find(_.hasSqlCommandResult) - .getOrElse(throw new RuntimeException("SQLCommandResult must be present")) - // Update the builder with the values from the result. - builder.mergeFrom(response.getSqlCommandResult.getRelation) - } finally { - // consume the rest of the iterator - responseIter.foreach(_ => ()) - } + override def sql(sqlText: String, args: java.util.Map[String, Any]): DataFrame = { + val sqlCommand = proto.SqlCommand + .newBuilder() + .setSql(sqlText) + .putAllNamedArguments(args.asScala.map { case (k, v) => (k, lit(v).expr) }.asJava) + .build() + sql(sqlCommand) } /** @inheritdoc */ @@ -271,6 +242,24 @@ class SparkSession private[sql] ( sql(query, Array.empty) } + private def sql(sqlCommand: proto.SqlCommand): DataFrame = newDataFrame { builder => + // Send the SQL once to the server and then check the output. + val cmd = newCommand(b => b.setSqlCommand(sqlCommand)) + val plan = proto.Plan.newBuilder().setCommand(cmd) + val responseIter = client.execute(plan.build()) + + try { + val response = responseIter + .find(_.hasSqlCommandResult) + .getOrElse(throw new RuntimeException("SQLCommandResult must be present")) + // Update the builder with the values from the result. + builder.mergeFrom(response.getSqlCommandResult.getRelation) + } finally { + // consume the rest of the iterator + responseIter.foreach(_ => ()) + } + } + /** @inheritdoc */ def read: DataFrameReader = new DataFrameReader(this)