Skip to content

Commit

Permalink
Supporting collect action on SGX #8
Browse files Browse the repository at this point in the history
  • Loading branch information
pgaref committed Jun 19, 2019
1 parent a554090 commit 128d51e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
27 changes: 25 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -943,8 +943,31 @@ abstract class RDD[T: ClassTag](
* all the data is loaded into the driver's memory.
*/
def collect(): Array[T] = withScope {
  if (!sc.getConf.isSGXWorkerEnabled()) {
    // Regular path: turn every partition into an array via a job, then join the
    // per-partition arrays into a single array on the driver.
    val perPartition = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
    Array.concat(perPartition: _*)
  } else {
    // SGX worker path: the function handed to SGXRDD drains its input iterator into
    // an array and exposes that array through a fresh iterator.
    val drainToArray = (itr: Iterator[Any]) => itr.toArray.iterator
    val sgxRdd = new SGXRDD(this, drainToArray, true)
    // NOTE(review): the original author states the results are encrypted
    // (Array[Byte]) at this point; nothing below performs a decryption step,
    // elements are only cast to T -- confirm against SGXRDD.
    val perPartition = sc.runJob(sgxRdd, (iter: Iterator[_]) => iter.toArray)
    if (sc.getConf.isSGXDriverEnabled()) {
      // Handing the data to an SGX driver (to be decrypted there) is not supported yet.
      throw new SGXException("Not implemented yet", new Exception("SGX exception"))
    } else {
      // Non-SGX driver: flatten the per-partition arrays and cast each element to T.
      perPartition.toIterator.flatten.map(_.asInstanceOf[T]).toArray
    }
  }
}

/**
Expand Down
11 changes: 11 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuiteSGX.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ class RDDSuiteSGX extends SparkFunSuite {
assert(res == 4)
}

// Runs a groupByKey shuffle followed by collect and checks both the number of
// distinct keys and the per-key sums, so a wrong aggregation or a corrupted
// collect result is caught (a size-only check would let those pass).
test("SGX shuffle operation") {
  val kvPairs = sc.parallelize(Array(
    ("USA", 1), ("USA", 2), ("USA", 8), ("USA", 3),
    ("UK", 6), ("UK", 9), ("UK", 5), ("UK", 1),
    ("India", 4), ("India", 9), ("India", 4), ("India", 1)
  ), 1)
  val res = kvPairs.groupByKey().map { case (country, values) => (country, values.sum) }
  val resK = res.collect()
  assert(resK.size == 3)
  // Compare as a Map: the order of keys after a shuffle is not part of the contract.
  assert(resK.toMap == Map("USA" -> 14, "UK" -> 21, "India" -> 18))
}

test("SGX Iterator Reader test") {
val baos = new ByteArrayOutputStream
val dos = new DataOutputStream(baos)
Expand Down

0 comments on commit 128d51e

Please sign in to comment.