Handle multiple result ids #84

Open · wants to merge 2 commits into base: master
1 change: 1 addition & 0 deletions README.md
@@ -20,6 +20,7 @@ You can link against this library in your program at the following ways:
<version>1.1.3</version>
</dependency>
```
or _2.12 for Scala 2.12
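The same binary-version suffix rule applies to sbt coordinates; a minimal sketch, assuming the com.springml group id and the spark-salesforce artifact name of this repository (neither is quoted from the README):

```
// sbt sketch only; the group id and artifact name are assumed, not taken from the README.
// %% appends the Scala binary suffix (_2.11 or _2.12) to the artifact name automatically.
libraryDependencies += "com.springml" %% "spark-salesforce" % "1.1.3"
```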

### SBT Dependency
```
12 changes: 5 additions & 7 deletions build.sbt
@@ -8,12 +8,13 @@ scalaVersion := "2.12.10"

crossScalaVersions := Seq("2.11.12", "2.12.10")

resolvers += Resolver.mavenLocal
resolvers += "sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots/"

libraryDependencies ++= Seq(
"com.force.api" % "force-wsc" % "40.0.0",
"com.force.api" % "force-partner-api" % "40.0.0",
"com.springml" % "salesforce-wave-api" % "1.0.10",
"com.force.api" % "force-wsc" % "52.2.0",
"com.force.api" % "force-partner-api" % "52.2.0",
"com.springml" % "salesforce-wave-api" % "1.0.8-uber-2",
"org.mockito" % "mockito-core" % "2.0.31-beta"
)

@@ -27,7 +28,6 @@ resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositori

resolvers += "sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots/"

resolvers += "Spark Package Main Repo" at "https://dl.bintray.com/spark-packages/maven"

libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.3" % "test"
libraryDependencies += "com.fasterxml.jackson.dataformat" % "jackson-dataformat-xml" % "2.4.4"
@@ -53,7 +53,7 @@ spDescription := """Spark Salesforce Wave Connector
| - Constructs Salesforce Wave dataset's metadata using schema present in dataframe
| - Can use custom metadata for constructing Salesforce Wave dataset's metadata""".stripMargin

// licenses += "Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0")
licenses += "Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0")

credentials += Credentials(Path.userHome / ".ivy2" / ".credentials")

@@ -86,5 +86,3 @@ pomExtra := (
<url>http://www.springml.com</url>
</developer>
</developers>)


3 changes: 2 additions & 1 deletion project/plugins.sbt
@@ -1,4 +1,5 @@
resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/"
// resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/"
resolvers += "bintray-spark-packages" at "https://repos.spark-packages.org"

addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.6")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.3")
163 changes: 137 additions & 26 deletions src/main/scala/com/springml/spark/salesforce/BulkRelation.scala
@@ -38,6 +38,9 @@ case class BulkRelation(
maxCharsPerColumn: Int) extends BaseRelation with TableScan {

import sqlContext.sparkSession.implicits._
import scala.collection.JavaConversions._

@transient lazy val logger: Logger = Logger.getLogger(classOf[BulkRelation])

def buildScan() = records.rdd

@@ -51,57 +54,165 @@
}

lazy val records: DataFrame = {
val inputJobInfo = new JobInfo("CSV", sfObject, "query")
val inputJobInfo = new JobInfo("CSV", sfObject, "queryAll")
val jobInfo = bulkAPI.createJob(inputJobInfo, customHeaders.asJava)
val jobId = jobInfo.getId
logger.error(">>> Obtained jobId: " + jobId)

val batchInfo = bulkAPI.addBatch(jobId, query)
logger.error(">>> Obtained batchInfo: " + batchInfo)

if (awaitJobCompleted(jobId)) {
bulkAPI.closeJob(jobId)

val batchInfoList = bulkAPI.getBatchInfoList(jobId)
val batchInfos = batchInfoList.getBatchInfo().asScala.toList

logger.error(">>> Obtained batchInfos: " + batchInfos)
logger.error(">>>>>> Obtained batchInfos.size: " + batchInfos.size)

val completedBatchInfos = batchInfos.filter(batchInfo => batchInfo.getState().equals("Completed"))
val completedBatchInfoIds = completedBatchInfos.map(batchInfo => batchInfo.getId)

val fetchBatchInfo = (batchInfoId: String) => {
val resultIds = bulkAPI.getBatchResultIds(jobId, batchInfoId)
logger.error(">>> Obtained completedBatchInfoIds: " + completedBatchInfoIds)
logger.error(">>> Obtained completedBatchInfoIds.size: " + completedBatchInfoIds.size)

val result = bulkAPI.getBatchResult(jobId, batchInfoId, resultIds.get(resultIds.size() - 1))
def splitCsvByRows(csvString: String): Seq[String] = {
// The CsvParser interface only interacts with IO, so StringReader and StringWriter
val inputReader = new StringReader(csvString)

// Use Csv parser to split CSV by rows to cover edge cases (ex. escaped characters, new line within string, etc)
def splitCsvByRows(csvString: String): Seq[String] = {
// The CsvParser interface only interacts with IO, so StringReader and StringWriter
val inputReader = new StringReader(csvString)
val parserSettings = new CsvParserSettings()
parserSettings.setLineSeparatorDetectionEnabled(true)
parserSettings.getFormat.setNormalizedNewline(' ')
parserSettings.setMaxCharsPerColumn(maxCharsPerColumn)

val readerParser = new CsvParser(parserSettings)
val parsedInput = readerParser.parseAll(inputReader).asScala

val outputWriter = new StringWriter()

val writerSettings = new CsvWriterSettings()
writerSettings.setQuoteAllFields(true)
writerSettings.setQuoteEscapingEnabled(true)

val writer = new CsvWriter(outputWriter, writerSettings)
parsedInput.foreach {
writer.writeRow(_)
}

outputWriter.toString.lines.toList
}

val fetchAllResults = (resultId: String, batchInfoId: String) => {
logger.error("Getting Result for ResultId: " + resultId)
val result = bulkAPI.getBatchResult(jobId, batchInfoId, resultId)

val splitRows = splitCsvByRows(result)

val parserSettings = new CsvParserSettings()
parserSettings.setLineSeparatorDetectionEnabled(true)
parserSettings.getFormat.setNormalizedNewline(' ')
parserSettings.setMaxCharsPerColumn(maxCharsPerColumn)
logger.error("Result Rows size: " + splitRows.size)
logger.error("Result Row - first: " + (if (splitRows.size > 0) splitRows.head else "not found"))

val readerParser = new CsvParser(parserSettings)
val parsedInput = readerParser.parseAll(inputReader).asScala
splitRows
}

val fetchBatchInfo = (batchInfoId: String) => {
logger.error(">>> About to fetch Results in batchInfoId: " + batchInfoId)

val outputWriter = new StringWriter()
val resultIds = bulkAPI.getBatchResultIds(jobId, batchInfoId)
logger.error(">>> Got ResultsIds in batchInfoId: " + resultIds)
logger.error(">>> Got ResultsIds in batchInfoId.size: " + resultIds.size)
logger.error(">>> Got ResultsIds in Last Result Id: " + resultIds.get(resultIds.size() - 1))

val writerSettings = new CsvWriterSettings()
writerSettings.setQuoteAllFields(true)
writerSettings.setQuoteEscapingEnabled(true)
// val result = bulkAPI.getBatchResult(jobId, batchInfoId, resultIds.get(resultIds.size() - 1))

val writer = new CsvWriter(outputWriter, writerSettings)
parsedInput.foreach { writer.writeRow(_) }
// logger.error(">>> Got Results - Results (string) length: " + result.length)

outputWriter.toString.lines.toList
// Use Csv parser to split CSV by rows to cover edge cases (ex. escaped characters, new line within string, etc)
// def splitCsvByRows(csvString: String): Seq[String] = {
// // The CsvParser interface only interacts with IO, so StringReader and StringWriter
// val inputReader = new StringReader(csvString)
//
// val parserSettings = new CsvParserSettings()
// parserSettings.setLineSeparatorDetectionEnabled(true)
// parserSettings.getFormat.setNormalizedNewline(' ')
// parserSettings.setMaxCharsPerColumn(maxCharsPerColumn)
//
// val readerParser = new CsvParser(parserSettings)
// val parsedInput = readerParser.parseAll(inputReader).asScala
//
// val outputWriter = new StringWriter()
//
// val writerSettings = new CsvWriterSettings()
// writerSettings.setQuoteAllFields(true)
// writerSettings.setQuoteEscapingEnabled(true)
//
// val writer = new CsvWriter(outputWriter, writerSettings)
// parsedInput.foreach { writer.writeRow(_) }
//
// outputWriter.toString.lines.toList
// }

val resultIdsBatchInfoIdPairs: List[(String, String)] = resultIds.toList.map { resultId: String => {
(resultId, batchInfoId)
}}

// AS addition - START
// val allRows: Seq[String] = resultIds.toList.flatMap { resultId: String => {
// logger.error("Getting Result for ResultId: " + resultId)
// val result = bulkAPI.getBatchResult(jobId, batchInfoId, resultId)
//
// val splitRows = splitCsvByRows(result)
//
// logger.error("Result Rows size: " + splitRows.size)
// logger.error("Result Row - first: " + (if (splitRows.size > 0) splitRows.head else "not found"))
//
// splitRows
// }}

val allRows: Seq[String] = resultIdsBatchInfoIdPairs.flatMap { case(resultId, batchInfoId) =>
fetchAllResults(resultId, batchInfoId)
}

splitCsvByRows(result)
allRows
// AS Addition - END

// val splitRows = splitCsvByRows(result)
// logger.error("Result Rows size: " + splitRows.size)
// logger.error("Result Row - first: " + (if (splitRows.size > 0) splitRows.head else "not found"))
// splitRows

}

// AS addition - START
val csvData: Dataset[String] = if (completedBatchInfoIds.size == 1) {
val resultIds = bulkAPI.getBatchResultIds(jobId, completedBatchInfoIds.head)

val resultIdsCompletedBatchInfoIdPairs: List[(String, String)] = resultIds.toList.map { resultId: String => {
(resultId, completedBatchInfoIds.head)
}}

logger.error(">>>> Will Parallelize Result IDs, CBatchInfoId: " + resultIdsCompletedBatchInfoIdPairs)

sqlContext
.sparkContext
.parallelize(resultIdsCompletedBatchInfoIdPairs)
.flatMap { case (resultId, batchInfoId) =>
fetchAllResults(resultId, batchInfoId)
}.toDS()
} else {
logger.error(">>>> Will Parallelize CompletedBatchInfoIds: " + completedBatchInfoIds)

sqlContext
.sparkContext
.parallelize(completedBatchInfoIds)
.flatMap(fetchBatchInfo).toDS()
}
// AS addition - END

val csvData = sqlContext
.sparkContext
.parallelize(completedBatchInfoIds)
.flatMap(fetchBatchInfo).toDS()
// val csvData = sqlContext
// .sparkContext
// .parallelize(completedBatchInfoIds)
// .flatMap(fetchBatchInfo).toDS()

sqlContext
.sparkSession
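Summing up the BulkRelation change: instead of downloading only the last result id of each completed batch, every result id is now fetched and the resulting CSV rows are concatenated. A minimal sketch of that core logic, assuming the salesforce-wave-api BulkAPI calls already used in the diff (getBatchInfoList, getBatchResultIds, getBatchResult) and the splitCsvByRows helper, with bulkAPI and jobId in scope:

```
// Sketch only: mirrors the intent of the diff, not a drop-in replacement.
import scala.collection.JavaConverters._

// All batches that finished successfully for this bulk job.
val completedBatchIds: List[String] =
  bulkAPI.getBatchInfoList(jobId).getBatchInfo.asScala.toList
    .filter(_.getState == "Completed")
    .map(_.getId)

// A batch can produce several result files; keep every (resultId, batchId) pair.
val resultPairs: List[(String, String)] = completedBatchIds.flatMap { batchId =>
  bulkAPI.getBatchResultIds(jobId, batchId).asScala.toList.map(resultId => (resultId, batchId))
}

// Previously only resultIds.get(resultIds.size() - 1) was fetched, silently dropping
// earlier result chunks; now every chunk is downloaded and split into CSV rows.
val allRows: Seq[String] = resultPairs.flatMap { case (resultId, batchId) =>
  splitCsvByRows(bulkAPI.getBatchResult(jobId, batchId, resultId))
}
```

In the diff itself these pairs (or, when there are several completed batches, the batch ids) are handed to sparkContext.parallelize, so the per-result downloads run on the executors rather than on the driver.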
src/main/scala/com/springml/spark/salesforce/DefaultSource.scala
@@ -189,27 +189,32 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr
if (sfObject.isEmpty) {
throw new Exception("sfObject must not be empty when performing bulk query")
}
logger.info("createBulkRelation :: sfObject: " + sfObject)

val maxCharsPerColumnStr = parameters.getOrElse("maxCharsPerColumn", "4096")
val maxCharsPerColumn = try {
maxCharsPerColumnStr.toInt
} catch {
case e: Exception => throw new Exception("maxCharsPerColumn must be a valid integer")
}
logger.info("createBulkRelation :: maxCharsPerColumn: " + maxCharsPerColumn)

val timeoutStr = parameters.getOrElse("timeout", "600000")
val timeout = try {
timeoutStr.toLong
} catch {
case e: Exception => throw new Exception("timeout must be a valid integer")
}
logger.info("createBulkRelation :: timeout: " + timeout)

var customHeaders = ListBuffer[Header]()
val pkChunkingStr = parameters.getOrElse("pkChunking", "false")
val pkChunking = flag(pkChunkingStr, "pkChunkingStr")
logger.info("createBulkRelation :: pkChunking: " + pkChunking)

if (pkChunking) {
val chunkSize = parameters.get("chunkSize")
logger.info("createBulkRelation :: chunkSize: " + chunkSize)

if (!chunkSize.isEmpty) {
try {
Expand All @@ -219,6 +224,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr
case e: Exception => throw new Exception("chunkSize must be a valid integer")
}
customHeaders += new BasicHeader("Sforce-Enable-PKChunking", s"chunkSize=${chunkSize.get}")
// customHeaders += new BasicHeader("Sforce-Enable-PKChunking", s"chunkSize=${chunkSize.get}; parent=Account")
} else {
customHeaders += new BasicHeader("Sforce-Enable-PKChunking", "true")
}
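For reference, the values validated and logged in createBulkRelation surface as data source options. A hypothetical usage sketch follows: only sfObject, maxCharsPerColumn, timeout, pkChunking and chunkSize appear in this diff; the format name, credential options and query option name are assumptions about the rest of the connector.

```
// Hypothetical usage sketch; option names other than sfObject, maxCharsPerColumn,
// timeout, pkChunking and chunkSize are assumptions, not taken from this diff.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("salesforce-bulk-read")
  .master("local[*]")
  .getOrCreate()

val df = spark.read
  .format("com.springml.spark.salesforce")          // assumed data source name
  .option("username", "user@example.com")           // assumed auth options
  .option("password", "password+securityToken")
  .option("bulk", "true")                           // assumed switch into the bulk query path
  .option("sfObject", "Account")                    // must not be empty for bulk queries (see above)
  .option("soql", "SELECT Id, Name FROM Account")   // assumed query option name
  .option("pkChunking", "true")                     // adds the Sforce-Enable-PKChunking header
  .option("chunkSize", "250000")                    // optional; must parse as an integer
  .option("maxCharsPerColumn", "8192")              // defaults to 4096 in the diff
  .option("timeout", "600000")                      // defaults to 600000 ms in the diff
  .load()

df.show(5)
```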