SPARKC-712: Update metrics for reads and writes via DSV2 #1369

Open · wants to merge 3 commits into base: master
@@ -0,0 +1,79 @@
/*
* Copyright DataStax, Inc.
*
* Please see the included license file for details.
*/

package com.datastax.spark.connector.datasource

import scala.collection.mutable

import com.datastax.spark.connector._
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.datasource.CassandraCatalog
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach


class CassandraCatalogMetricsSpec extends SparkCassandraITFlatSpecBase with DefaultCluster with BeforeAndAfterEach {

  override lazy val conn = CassandraConnector(defaultConf)

  override lazy val spark = SparkSession.builder()
    .config(sparkConf
      // Enable Codahale/Dropwizard metrics
      .set("spark.metrics.conf.executor.source.cassandra-connector.class", "org.apache.spark.metrics.CassandraConnectorSource")
      .set("spark.metrics.conf.driver.source.cassandra-connector.class", "org.apache.spark.metrics.CassandraConnectorSource")
      .set("spark.sql.sources.useV1SourceList", "")
      .set("spark.sql.defaultCatalog", "cassandra")
      .set("spark.sql.catalog.cassandra", classOf[CassandraCatalog].getCanonicalName)
    )
    .withExtensions(new CassandraSparkExtensions).getOrCreate().newSession()

  override def beforeClass {
    conn.withSessionDo { session =>
      session.execute(s"CREATE KEYSPACE IF NOT EXISTS $ks WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }")
      session.execute(s"CREATE TABLE IF NOT EXISTS $ks.leftjoin (key INT, x INT, PRIMARY KEY (key))")
      for (i <- 1 to 1000 * 10) {
        session.execute(s"INSERT INTO $ks.leftjoin (key, x) values ($i, $i)")
      }
    }
  }

  var readRowCount: Long = 0
  var readByteCount: Long = 0

  it should "update Codahale read metrics for SELECT queries" in {
    val df = spark.sql(s"SELECT x FROM $ks.leftjoin LIMIT 2")
    val metricsRDD = df.queryExecution.toRdd.mapPartitions { iter =>
      val tc = org.apache.spark.TaskContext.get()
      val source = org.apache.spark.metrics.MetricsUpdater.getSource(tc)
      Iterator((source.get.readRowMeter.getCount, source.get.readByteMeter.getCount))
    }

    val metrics = metricsRDD.collect()
    readRowCount = metrics.map(_._1).sum - readRowCount
    readByteCount = metrics.map(_._2).sum - readByteCount

    assert(readRowCount > 0)
    assert(readByteCount == readRowCount * 4) // 4 bytes per INT result
  }

  it should "update Codahale read metrics for COUNT queries" in {
    val df = spark.sql(s"SELECT COUNT(*) FROM $ks.leftjoin")
    val metricsRDD = df.queryExecution.toRdd.mapPartitions { iter =>
      val tc = org.apache.spark.TaskContext.get()
      val source = org.apache.spark.metrics.MetricsUpdater.getSource(tc)
      Iterator((source.get.readRowMeter.getCount, source.get.readByteMeter.getCount))
    }

    val metrics = metricsRDD.collect()
    readRowCount = metrics.map(_._1).sum - readRowCount
    readByteCount = metrics.map(_._2).sum - readByteCount

    assert(readRowCount > 0)
    assert(readByteCount == readRowCount * 8) // 8 bytes per COUNT result
  }
}
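The spec above doubles as a configuration reference. As a minimal sketch that reuses only the keys and class names already present in the test, the same Codahale/Dropwizard source can be enabled in a user application like this:

import org.apache.spark.sql.SparkSession

// Minimal sketch: register the connector's Codahale source on the driver and executors
// and route SQL through the DSV2 catalog, mirroring the configuration used in the spec.
val spark = SparkSession.builder()
  .config("spark.metrics.conf.driver.source.cassandra-connector.class",
    "org.apache.spark.metrics.CassandraConnectorSource")
  .config("spark.metrics.conf.executor.source.cassandra-connector.class",
    "org.apache.spark.metrics.CassandraConnectorSource")
  .config("spark.sql.catalog.cassandra", "com.datastax.spark.connector.datasource.CassandraCatalog")
  .config("spark.sql.defaultCatalog", "cassandra")
  .config("spark.sql.sources.useV1SourceList", "")
  .getOrCreate()

With the source registered, the readRowMeter and readByteMeter counts exercised by the tests are reported through whatever metrics sinks Spark is configured with.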
@@ -60,17 +60,17 @@ class CassandraCatalogTableReadSpec extends CassandraCatalogSpecBase {
  it should "handle count pushdowns" in {
    setupBasicTable()
    val request = spark.sql(s"""SELECT COUNT(*) from $defaultKs.$testTable""")
-    val reader = request
+    var factory = request
      .queryExecution
      .executedPlan
      .collectFirst {
-        case batchScanExec: BatchScanExec => batchScanExec.readerFactory.createReader(EmptyInputPartition)
-        case adaptiveSparkPlanExec: AdaptiveSparkPlanExec => adaptiveSparkPlanExec.executedPlan.collectLeaves().collectFirst {
-          case batchScanExec: BatchScanExec => batchScanExec.readerFactory.createReader(EmptyInputPartition)
-        }.get
+        case batchScanExec: BatchScanExec => batchScanExec.readerFactory
+        case adaptiveSparkPlanExec: AdaptiveSparkPlanExec => adaptiveSparkPlanExec.executedPlan.collectLeaves().collectFirst {
+          case batchScanExec: BatchScanExec => batchScanExec.readerFactory
+        }.get
      }

-    reader.get.isInstanceOf[CassandraCountPartitionReader] should be (true)
+    factory.get.asInstanceOf[CassandraScanPartitionReaderFactory].isCountQuery should be (true)
    request.collect()(0).get(0) should be (101)
  }

@@ -12,6 +12,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.sources.In
import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.metrics.InputMetricsUpdater
+import org.apache.spark.TaskContext

import scala.util.{Failure, Success}

@@ -62,16 +64,18 @@ abstract class CassandraBaseInJoinReader(
  protected val maybeRateLimit = JoinHelper.maybeRateLimit(readConf)
  protected val requestsPerSecondRateLimiter = JoinHelper.requestsPerSecondRateLimiter(readConf)

+  protected val metricsUpdater = InputMetricsUpdater(TaskContext.get(), readConf)
  protected def pairWithRight(left: CassandraRow): SettableFuture[Iterator[(CassandraRow, InternalRow)]] = {
    val resultFuture = SettableFuture.create[Iterator[(CassandraRow, InternalRow)]]
    val leftSide = Iterator.continually(left)

    queryExecutor.executeAsync(bsb.bind(left).executeAs(readConf.executeAs)).onComplete {
      case Success(rs) =>
        val resultSet = new PrefetchingResultSetIterator(rs)
+        val iteratorWithMetrics = resultSet.map(metricsUpdater.updateMetrics)
        /* This is a much less than ideal place to actually rate limit; because we are buffering
           these futures, we will most likely exceed our threshold. */
-        val throttledIterator = resultSet.map(maybeRateLimit)
+        val throttledIterator = iteratorWithMetrics.map(maybeRateLimit)
        val rightSide = throttledIterator.map(rowReader.read(_, rowMetadata))
        resultFuture.set(leftSide.zip(rightSide))
      case Failure(throwable) =>
@@ -103,6 +107,7 @@ abstract class CassandraBaseInJoinReader(
  override def get(): InternalRow = currentRow

  override def close(): Unit = {
+    metricsUpdater.finish()
    session.close()
  }
}
@@ -12,6 +12,8 @@ import com.datastax.spark.connector.util.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.metrics.InputMetricsUpdater
+import org.apache.spark.TaskContext

case class CassandraScanPartitionReaderFactory(
    connector: CassandraConnector,
@@ -20,10 +22,12 @@ case class CassandraScanPartitionReaderFactory(
    readConf: ReadConf,
    queryParts: CqlQueryParts) extends PartitionReaderFactory {

+  def isCountQuery: Boolean = queryParts.selectedColumnRefs.contains(RowCountRef)
+
  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {

    val cassandraPartition = partition.asInstanceOf[CassandraPartition[Any, _ <: Token[Any]]]
-    if (queryParts.selectedColumnRefs.contains(RowCountRef)) {
+    if (isCountQuery) {
      //Count Pushdown
      CassandraCountPartitionReader(
        connector,
@@ -61,6 +65,8 @@ abstract class CassandraPartitionReaderBase
  protected val rowIterator = getIterator()
  protected var lastRow: InternalRow = InternalRow()

+  protected val metricsUpdater = InputMetricsUpdater(TaskContext.get(), readConf)
+
  override def next(): Boolean = {
    if (rowIterator.hasNext) {
      lastRow = rowIterator.next()
@@ -73,6 +79,7 @@ abstract class CassandraPartitionReaderBase
  override def get(): InternalRow = lastRow

  override def close(): Unit = {
+    metricsUpdater.finish()
    scanner.close()
  }

@@ -107,7 +114,8 @@ abstract class CassandraPartitionReaderBase
    tokenRanges.iterator.flatMap { range =>
      val scanResult = ScanHelper.fetchTokenRange(scanner, tableDef, queryParts, range, readConf.consistencyLevel, readConf.fetchSizeInRows)
      val meta = scanResult.metadata
-      scanResult.rows.map(rowReader.read(_, meta))
+      val iteratorWithMetrics = scanResult.rows.map(metricsUpdater.updateMetrics)
+      iteratorWithMetrics.map(rowReader.read(_, meta))
    }
  }

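The read-path change above threads every fetched row through metricsUpdater.updateMetrics before the row reader sees it, so metering stays lazy and never materializes the iterator. A stand-alone sketch of the same wrap-an-iterator idea, using the Dropwizard Meter type directly (the names below are illustrative, not the connector's API):

import com.codahale.metrics.Meter

// Illustrative sketch: mark a meter for each element flowing through an iterator
// while preserving its lazy, streaming behaviour.
def metered[T](rows: Iterator[T], rowMeter: Meter): Iterator[T] =
  rows.map { row =>
    rowMeter.mark() // one row observed
    row
  }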
@@ -7,6 +7,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory
import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.metrics.OutputMetricsUpdater
+import org.apache.spark.TaskContext

case class CassandraDriverDataWriterFactory(
    connector: CassandraConnector,
@@ -36,22 +38,31 @@ case class CassandraDriverDataWriter(

  private val columns = SomeColumns(inputSchema.fieldNames.map(name => ColumnName(name)): _*)

-  private val writer =
+  private val metricsUpdater = OutputMetricsUpdater(TaskContext.get(), writeConf)
+
+  private val asyncWriter =
    TableWriter(connector, tableDef, columns, writeConf, false)(unsafeRowWriterFactory)
      .getAsyncWriter()

+  private val writer = asyncWriter.copy(
+    successHandler = Some(metricsUpdater.batchFinished(success = true, _, _, _)),
+    failureHandler = Some(metricsUpdater.batchFinished(success = false, _, _, _)))
+
  override def write(record: InternalRow): Unit = writer.write(record)

  override def commit(): WriterCommitMessage = {
+    metricsUpdater.finish()
    writer.close()
    CassandraCommitMessage()
  }

  override def abort(): Unit = {
+    metricsUpdater.finish()
    writer.close()
  }

  override def close(): Unit = {
+    metricsUpdater.finish()
    // Our proxy Session Handler handles double closes by ignoring them, so this is fine
    writer.close()
  }
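For completeness, a hedged end-to-end sketch of a write that would flow through this writer when the catalog is configured as in the spec earlier; the keyspace and table names are illustrative:

// Illustrative names; assumes a SparkSession `spark` with the "cassandra" catalog
// configured as shown above and an existing table test.kv (key INT PRIMARY KEY, x INT).
spark.range(0, 1000)
  .selectExpr("CAST(id AS INT) AS key", "CAST(id AS INT) AS x")
  .writeTo("cassandra.test.kv")
  .append()

Rows written through this path are accounted for by the OutputMetricsUpdater hooked into the success and failure handlers above.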