diff --git a/connector/src/main/scala/com/basho/riak/spark/query/QueryTS.scala b/connector/src/main/scala/com/basho/riak/spark/query/QueryTS.scala index 41e9576c..52ec88d9 100644 --- a/connector/src/main/scala/com/basho/riak/spark/query/QueryTS.scala +++ b/connector/src/main/scala/com/basho/riak/spark/query/QueryTS.scala @@ -17,25 +17,34 @@ */ package com.basho.riak.spark.query -import java.sql.Timestamp import java.util.concurrent.ExecutionException + import com.basho.riak.client.core.netty.RiakResponseException import com.basho.riak.client.core.operations.ts.QueryOperation -import com.basho.riak.client.core.query.timeseries.{ Row, ColumnDescription } -import com.basho.riak.client.core.util.BinaryValue -import com.basho.riak.spark.rdd.connector.RiakSession -import com.basho.riak.spark.rdd.{ BucketDef, ReadConf } +import com.basho.riak.client.core.query.timeseries.{ColumnDescription, Row} + import scala.collection.convert.decorateAsScala._ import com.basho.riak.client.core.query.timeseries.CoverageEntry import com.basho.riak.spark.rdd.connector.RiakConnector import com.basho.riak.client.core.util.HostAndPort +import com.basho.riak.spark.util.{Dumpable, DumpUtils} /** * @author Sergey Galkin * @since 1.1.0 */ -case class TSQueryData(sql: String, coverageEntry: Option[CoverageEntry] = None) { +case class TSQueryData(sql: String, coverageEntry: Option[CoverageEntry] = None) extends Dumpable { val primaryHost = coverageEntry.map(e => HostAndPort.fromParts(e.getHost, e.getPort)) + + override def dump(lineSep: String = "\n"): String = { + val optional = coverageEntry match { + case Some(ce) => lineSep + s"primary-host: ${primaryHost.get.getHost}:${primaryHost.get.getPort}" + lineSep + + "coverage-entry:" + DumpUtils.dump(ce, lineSep + " ") + case None => "" + } + + s"sql: {${sql.toLowerCase.replaceAll("\n", "")}}" + optional + } } /** diff --git a/connector/src/main/scala/com/basho/riak/spark/rdd/RiakPartition.scala b/connector/src/main/scala/com/basho/riak/spark/rdd/RiakPartition.scala index ec9627c3..d2ca4daa 100644 --- a/connector/src/main/scala/com/basho/riak/spark/rdd/RiakPartition.scala +++ b/connector/src/main/scala/com/basho/riak/spark/rdd/RiakPartition.scala @@ -18,9 +18,10 @@ package com.basho.riak.spark.rdd import com.basho.riak.client.core.util.HostAndPort +import com.basho.riak.spark.util.Dumpable import org.apache.spark.Partition -trait RiakPartition extends Partition{ +trait RiakPartition extends Partition with Dumpable { def endpoints: Iterable[HostAndPort] } diff --git a/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/PartitioningUtils.scala b/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/PartitioningUtils.scala index dc8447a4..824a3670 100644 --- a/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/PartitioningUtils.scala +++ b/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/PartitioningUtils.scala @@ -7,7 +7,12 @@ object PartitioningUtils { def splitListEvenly[A](list: Seq[A], splitCount: Int): Iterator[Seq[A]] = { val (base, rem) = divide(list.size, splitCount) val (smaller, bigger) = list.splitAt(list.size - rem * (base + 1)) - smaller.grouped(base) ++ bigger.grouped(base + 1) + + if (smaller.isEmpty) { + bigger.grouped(base + 1) + } else { + smaller.grouped(base) ++ bigger.grouped(base + 1) + } } // e.g. 
split 64 coverage entries into 10 partitions: (6,6,6,6,6,6,7,7,7,7) coverage entries in partitions respectively @@ -23,11 +28,11 @@ object PartitioningUtils { yield if (i < rem) base + 1 else base } - def divide(size: Long, splitCount: Int): (Long, Long) = { + private def divide(size: Long, splitCount: Int): (Long, Long) = { (size / splitCount, size % splitCount) } - def divide(size: Int, splitCount: Int): (Int, Int) = { + private def divide(size: Int, splitCount: Int): (Int, Int) = { (size / splitCount, size % splitCount) } diff --git a/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/RiakTSPartitioner.scala b/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/RiakTSPartitioner.scala index 0141a783..98617c48 100644 --- a/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/RiakTSPartitioner.scala +++ b/connector/src/main/scala/com/basho/riak/spark/rdd/partitioner/RiakTSPartitioner.scala @@ -18,19 +18,22 @@ package com.basho.riak.spark.rdd.partitioner import java.sql.Timestamp -import org.apache.spark.Partition + +import org.apache.spark.{Logging, Partition} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{ StructType, TimestampType } +import org.apache.spark.sql.types.StructType import com.basho.riak.client.core.util.HostAndPort import com.basho.riak.spark.query.TSQueryData -import com.basho.riak.spark.rdd.{ ReadConf, RiakPartition } +import com.basho.riak.spark.rdd.{ReadConf, RiakPartition} import com.basho.riak.client.core.netty.RiakResponseException import com.basho.riak.client.api.commands.timeseries.CoveragePlan import com.basho.riak.spark.rdd.partitioner.PartitioningUtils._ import com.basho.riak.spark.rdd.connector.RiakConnector + import scala.collection.JavaConversions._ import scala.util.control.Exception._ import com.basho.riak.client.core.query.timeseries.CoverageEntry +import com.basho.riak.spark.util.DumpUtils /** * @author Sergey Galkin @@ -38,7 +41,22 @@ import com.basho.riak.client.core.query.timeseries.CoverageEntry case class RiakTSPartition( index: Int, endpoints: Iterable[HostAndPort], - queryData: Seq[TSQueryData]) extends RiakPartition + queryData: Seq[TSQueryData]) extends RiakPartition { + + override def dump(lineSep: String = "\n"): String = + s"[$index] eps: " + endpoints.foldLeft(new StringBuilder) { + (sb, h) => { + if (!sb.isEmpty) { + sb.append(',').append(' ') + } + + sb append h.getHost append (':') append (h.getPort) + } + }.append('\n') + .append(s" queryData (${queryData.size}):\n") + .append(queryData.foldLeft(new StringBuilder) { (sb, qd) => sb.append(" ").append(qd.dump("\n ")).append("\n\n") }) + .toString() +} trait RiakTSPartitioner { @@ -193,7 +211,6 @@ object RangedRiakTSPartitioner { columnNames: Option[Seq[String]], filters: Array[Filter], readConf: ReadConf): RangedRiakTSPartitioner = { new RiakTSCoveragePlanBasedPartitioner(connector, tableName, schema, columnNames, filters, readConf) } - } /** Splits initial range query into readConf.splitCount number of sub-ranges, each in a separate partition */ @@ -283,7 +300,8 @@ class AutomaticRangedRiakTSPartitioner(connector: RiakConnector, tableName: Stri } class RiakTSCoveragePlanBasedPartitioner(connector: RiakConnector, tableName: String, schema: Option[StructType], - columnNames: Option[Seq[String]], filters: Array[Filter], readConf: ReadConf) extends RangedRiakTSPartitioner(tableName, schema, columnNames, filters, readConf) { + columnNames: Option[Seq[String]], filters: Array[Filter], readConf: ReadConf) extends 
RangedRiakTSPartitioner(tableName, schema, columnNames, filters, readConf) +with Logging { val where = whereClause(filters) val (queryRaw, vals) = toSql(columnNames, tableName, schema, where) @@ -306,16 +324,27 @@ class RiakTSCoveragePlanBasedPartitioner(connector: RiakConnector, tableName: St override lazy val tsRangeFieldName = coveragePlan.head.getFieldName override def partitions(): Array[Partition] = { - val hosts = coveragePlan.hosts require(splitCount >= hosts.size) val coverageEntriesCount = coveragePlan.size val partitionsCount = if (splitCount <= coverageEntriesCount) splitCount else coverageEntriesCount - val evenDistributionBetweenHosts = distributeEvenly(partitionsCount, hosts.size) + if (log.isTraceEnabled()) { + val cp = coveragePlan.foldLeft(new StringBuilder) { (sb, ce) => sb.append( DumpUtils.dump(ce, "\n ")).append("\n\n") } + + logTrace("\n----------------------------------------\n" + + s" [Auto TS Partitioner] Requested: split up to $splitCount partitions\n" + + s" Actually: the only $partitionsCount partitions might be created\n" + + "--\n" + + s"Coverage plan ($coverageEntriesCount coverage entries):\n$cp\n" + + "----------------------------------------\n") + } + + val evenPartitionDistributionBetweenHosts = distributeEvenly(partitionsCount, hosts.size) + val numberOfEntriesInPartitionPerHost = - (hosts zip evenDistributionBetweenHosts).flatMap { case (h, num) => splitListEvenly(coveragePlan.hostEntries(h), num).map((h, _)) } + (hosts zip evenPartitionDistributionBetweenHosts) flatMap { case (h, num) => splitListEvenly(coveragePlan.hostEntries(h), num) map{(h, _)} } val partitions = for { ((host, coverageEntries), partitionIdx) <- numberOfEntriesInPartitionPerHost.zipWithIndex @@ -323,6 +352,23 @@ class RiakTSCoveragePlanBasedPartitioner(connector: RiakConnector, tableName: St partition = RiakTSPartition(partitionIdx, hosts.toSet, tsQueryData) } yield partition - partitions.toArray + val result = partitions.toArray + + if (log.isDebugEnabled()) { + val p = result.foldLeft(new StringBuilder) { (sb, r) => sb.append(r.dump()).append("\n") }.toString() + + logInfo("\n----------------------------------------\n" + + s" [Auto TS Partitioner] Requested: split up to $splitCount partitions\n" + + s" Actually: the created partitions are:\n" + + "--\n" + + s"$p\n" + + "----------------------------------------\n") + } + + // Double check that all coverage entries were used + val numberOfUsedCoverageEntries = partitions.foldLeft(0){ (sum, p) => sum + p.queryData.size} + require( numberOfUsedCoverageEntries == coverageEntriesCount) + + result.asInstanceOf[Array[Partition]] } } diff --git a/connector/src/main/scala/com/basho/riak/spark/util/Dumpable.scala b/connector/src/main/scala/com/basho/riak/spark/util/Dumpable.scala new file mode 100644 index 00000000..8bdeb18d --- /dev/null +++ b/connector/src/main/scala/com/basho/riak/spark/util/Dumpable.scala @@ -0,0 +1,47 @@ +/** + * Copyright (c) 2015-2016 Basho Technologies, Inc. + * + * This file is provided to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file + * except in compliance with the License. You may obtain + * a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package com.basho.riak.spark.util + +import com.basho.riak.client.core.query.timeseries.CoverageEntry + +trait Dumpable { + def dump(lineSep: String = "\n"): String = toString +} + + +object DumpUtils { + + def dump(ce: CoverageEntry, lineSep: String): String = { + val lb = ce.isLowerBoundInclusive match { + case true => "[" + case false => "(" + } + + val ub = ce.isUpperBoundInclusive match { + case true => "]" + case false => ")" + } + + s"$lb${ce.getLowerBound},${ce.getUpperBound}$ub@host: ${ce.getHost}:${ce.getPort}" + lineSep + + s"description: ${ce.getDescription}" + lineSep + + s"context: " + {ce.getCoverageContext match { + case null => "null" + case c => c.map("%02X" format _).mkString + }} + } +} \ No newline at end of file diff --git a/connector/src/test/scala/com/basho/riak/spark/rdd/partitioner/RiakTSCoveragePlanBasedPartitionerTest.scala b/connector/src/test/scala/com/basho/riak/spark/rdd/partitioner/RiakTSCoveragePlanBasedPartitionerTest.scala new file mode 100644 index 00000000..141329b2 --- /dev/null +++ b/connector/src/test/scala/com/basho/riak/spark/rdd/partitioner/RiakTSCoveragePlanBasedPartitionerTest.scala @@ -0,0 +1,265 @@ +/** + * ***************************************************************************** + * Copyright (c) 2016 IBM Corp. + * + * Created by Basho Technologies for IBM + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ***************************************************************************** + */ +package com.basho.riak.spark.rdd.partitioner + + +import com.basho.riak.JsonTestFunctions +import com.basho.riak.client.api.commands.timeseries.CoveragePlan +import com.basho.riak.client.core.query.timeseries.{CoverageEntry, CoveragePlanResult} +import com.basho.riak.spark.rdd.{ReadConf, RegressionTests, RiakTSRDD} +import com.basho.riak.spark.rdd.connector.{RiakConnector, RiakSession} +import com.fasterxml.jackson.core.{JsonGenerator, Version} +import com.fasterxml.jackson.databind.module.SimpleModule +import com.fasterxml.jackson.databind.{JsonSerializer, ObjectMapper, SerializerProvider} +import org.apache.spark.SparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.sources.{Filter, GreaterThanOrEqual, LessThan} +import org.junit.{Before, Test} +import org.junit.experimental.categories.Category +import org.junit.runner.RunWith +import org.mockito.Matchers._ +import org.mockito.Mock +import org.mockito.runners.MockitoJUnitRunner +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import scala.collection.JavaConversions._ + +@RunWith(classOf[MockitoJUnitRunner]) +class RiakTSCoveragePlanBasedPartitionerTest extends JsonTestFunctions { + + @Mock + private val rc: RiakConnector = null + + @Mock + private val rs: RiakSession = null + + @Mock + private val sc: SparkContext = null + + // To access protected constructor in CoveragePlanResult + class SimpleCoveragePlanResult extends CoveragePlanResult { + } + + val filters: Array[Filter] = Array( + GreaterThanOrEqual("time", 0), + LessThan("time", 1000)) + + private var coveragePlan: SimpleCoveragePlanResult = null + + override protected def tolerantMapper: ObjectMapper = super.tolerantMapper + .registerModule( + new SimpleModule("RiakTs2 Module", new Version(1, 0, 0, null)) + .addSerializer(classOf[RiakTSPartition], new RiakTSPartitionSerializer) + .addSerializer(classOf[CoverageEntry], new RiakCoverageEntrySerializer)) + + @Before + def initializeMocks(): Unit = { + doAnswer(new Answer[CoveragePlanResult] { + override def answer(invocation: InvocationOnMock): CoveragePlanResult = coveragePlan + }).when(rs).execute(any[CoveragePlan]) + + doAnswer(new Answer[AnyRef] { + override def answer(invocation: InvocationOnMock) = { + val func = invocation.getArguments()(0).asInstanceOf[RiakSession => AnyRef] + func.apply(rs) + } + }).when(rc).withSessionDo(any(classOf[Function1[RiakSession, CoveragePlanResult]])) + + coveragePlan = new SimpleCoveragePlanResult + } + + @Test + @Category(Array(classOf[RegressionTests])) + def checkPartitioningForIrregularData(): Unit = { + + // host -> range(from->to) + makeCoveragePlan( + ("h1", 1->2), + ("h2", 3->4), + ("h2", 5->6), + ("h2", 7->8), + ("h3", 11->12) + ) + + val partitioner = new RiakTSCoveragePlanBasedPartitioner(rc, "test", None, None, new Array[Filter](0), new ReadConf()) + val partitions = partitioner.partitions() + assertEqualsUsingJSONIgnoreOrder( + """[ + | {index: 0, queryData: {primaryHost: 'h3:0', entry: '[11,12)@h3'}}, + | + | {index: 1, queryData:[ + | {primaryHost: 'h2:0', entry: '[5,6)@h2'}, + | {primaryHost: 'h2:0', entry: '[7,8)@h2'}]}, + | + | {index: 2, queryData: {primaryHost: 'h1:0', entry: '[1,2)@h1'}}, + | + | {index: 3, queryData: {primaryHost: 'h2:0', entry: '[3,4)@h2'}} + ]""".stripMargin, partitions) + } + + @Test + @Category(Array(classOf[RegressionTests])) + def checkPartitioningForRegullarData(): Unit = { + 
+ // host -> range(from->to) + makeCoveragePlan( + ("1", 1 -> 2), + ("1", 3 -> 4), + ("2", 5 -> 6), + ("2", 7 -> 8), + ("3", 9 -> 10) + ) + + val partitioner = new RiakTSCoveragePlanBasedPartitioner(rc, "test", None, None, new Array[Filter](0), new ReadConf()) + val partitions = partitioner.partitions() + + + assertEqualsUsingJSONIgnoreOrder( + """[ + | {index: 0, queryData: {primaryHost: '1:0', entry: '[1,2)@1'}}, + | {index: 1, queryData: {primaryHost: '1:0', entry: '[3,4)@1'}}, + | {index: 2, queryData: {primaryHost: '2:0', entry: '[5,6)@2'}}, + | {index: 3, queryData: {primaryHost: '3:0', entry: '[9,10)@3'}}, + | {index: 4, queryData: {primaryHost: '2:0', entry: '[7,8)@2'}} + ]""".stripMargin, partitions) + } + + @Test + def coveragePlanBasedPartitioningLessThanSplitCount(): Unit = { + makeCoveragePlan( + ("h1", 1 -> 2), + ("h2", 3 -> 4), + ("h3", 5 -> 6) + ) + + val rdd = new RiakTSRDD[Row](sc, rc, "test", None, None, None, filters) + val partitions = rdd.partitions + assertEqualsUsingJSONIgnoreOrder( + """[ + | {index: 0, queryData: {primaryHost: 'h3:0', entry: '[5,6)@h3'}}, + | {index: 1, queryData: {primaryHost: 'h1:0', entry: '[1,2)@h1'}}, + | {index: 2, queryData: {primaryHost: 'h2:0', entry: '[3,4)@h2'}} + ]""".stripMargin, partitions) + } + + @Test + def coveragePlanBasedPartitioningGreaterThanSplitCount(): Unit = { + val requestedSplitCount = 3 + + makeCoveragePlan( + ("h1", 1 -> 2), + ("h1", 3 -> 4), + ("h1", 5 -> 6), + ("h2", 6 -> 7), + ("h2", 8 -> 9), + ("h2", 10 -> 11), + ("h2", 12 -> 13), + ("h3", 14 -> 15), + ("h3", 16 -> 17), + ("h3", 18 -> 19) + ) + val rdd = new RiakTSRDD[Row](sc, rc, "test", None, None, None, filters, readConf = ReadConf(splitCount=requestedSplitCount)) + val partitions = rdd.partitions + + assertEqualsUsingJSONIgnoreOrder("""[ + | {index:0, queryData:[ + | {primaryHost: 'h3:0', entry: '[14,15)@h3'}, + | {primaryHost: 'h3:0', entry: '[16,17)@h3'}, + | {primaryHost: 'h3:0', entry: '[18,19)@h3'}]}, + | + | {index:1,queryData:[ + | {primaryHost: 'h2:0', entry: '[6,7)@h2'}, + | {primaryHost: 'h2:0', entry: '[8,9)@h2'}, + | {primaryHost: 'h2:0', entry: '[10,11)@h2'}, + | {primaryHost: 'h2:0', entry: '[12,13)@h2'}]}, + | + | {index:2,queryData:[ + | {primaryHost: 'h1:0', entry: '[1,2)@h1'}, + | {primaryHost: 'h1:0', entry: '[3,4)@h1'}, + | {primaryHost: 'h1:0', entry: '[5,6)@h1'}]} + ]""".stripMargin, partitions) + } + + private def makeCoveragePlan(entries: Tuple2[String, Tuple2[Int, Int]]*): Unit = { + coveragePlan = new SimpleCoveragePlanResult + + entries.foreach(e => { + val (host, range) = e + + val ce = new CoverageEntry() + ce.setFieldName("time") + ce.setHost(host) + ce.setLowerBoundInclusive(true) + ce.setLowerBound(range._1) + ce.setUpperBoundInclusive(false) + ce.setUpperBound(range._2) + + ce.setDescription(s"table / time >= ${range._1} AND time < ${range._2}") + + coveragePlan.addEntry(ce) + }) + } + + private class RiakTSPartitionSerializer extends JsonSerializer[RiakTSPartition] { + override def serialize(value: RiakTSPartition, jgen: JsonGenerator, provider: SerializerProvider): Unit = { + jgen.writeStartObject() + jgen.writeNumberField("index", value.index) + + jgen.writeFieldName("queryData") + + if (value.queryData.length >1) { + jgen.writeStartArray() + } + + value.queryData.foreach(qd => { + jgen.writeStartObject() + if (qd.primaryHost.isDefined) { + jgen.writeObjectField("primaryHost", qd.primaryHost.get) + } + jgen.writeObjectField("entry", qd.coverageEntry) + jgen.writeEndObject() + }) + + if (value.queryData.length >1) { + 
jgen.writeEndArray() + } + + jgen.writeEndObject() + } + } + + class RiakCoverageEntrySerializer extends JsonSerializer[CoverageEntry] { + override def serialize(ce: CoverageEntry, jgen: JsonGenerator, provider: SerializerProvider): Unit = { + val lb = ce.isLowerBoundInclusive match { + case true => "[" + case false => "(" + } + + val ub = ce.isUpperBoundInclusive match { + case true => "]" + case false => ")" + } + + jgen.writeString(s"$lb${ce.getLowerBound},${ce.getUpperBound}$ub@${ce.getHost}") + } + } +} \ No newline at end of file diff --git a/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesPartitioningTest.scala b/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesPartitioningTest.scala index d4675624..b16cb1b9 100644 --- a/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesPartitioningTest.scala +++ b/connector/src/test/scala/com/basho/riak/spark/rdd/timeseries/TimeSeriesPartitioningTest.scala @@ -20,28 +20,11 @@ package com.basho.riak.spark.rdd.timeseries import java.sql.Timestamp -import java.util - -import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ -import org.apache.spark.SparkConf -import org.apache.spark.sql.Row import org.apache.spark.sql.sources.EqualTo import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.sources.GreaterThanOrEqual import org.apache.spark.sql.sources.LessThan import org.junit.Test -import org.mockito.Matchers._ -import org.mockito.Mockito._ -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import com.basho.riak.client.api.commands.timeseries.CoveragePlan -import com.basho.riak.client.core.query.timeseries.CoverageEntry -import com.basho.riak.client.core.query.timeseries.CoveragePlanResult -import com.basho.riak.client.core.util.HostAndPort -import com.basho.riak.spark.rdd.RiakTSRDD -import com.basho.riak.spark.rdd.connector.RiakConnector -import com.basho.riak.spark.rdd.connector.RiakSession import com.basho.riak.spark.rdd.partitioner.RiakTSPartition import com.basho.riak.spark.toSparkContextFunctions import org.junit.Assert.assertEquals @@ -250,70 +233,4 @@ class TimeSeriesPartitioningTest extends AbstractTimeSeriesTest(createTestData = .filter(s"user_id = 'user1'") val partitions = df.rdd.partitions } - - @Test - def coveragePlanBasedPartitioningLessThanSplitCount(): Unit = { - - val connector = spy(RiakConnector(new SparkConf())) - val session = mock(classOf[RiakSession]) - val coveragePlan = mock(classOf[CoveragePlanResult]) - doReturn(session).when(connector).openSession(Some(Seq(any(classOf[HostAndPort])))) - when(session.execute(any(classOf[CoveragePlan]))).thenReturn(coveragePlan) - - val mapCE = createCoverageEntry(3, 3) - mapCE.keys.foreach(h => when(coveragePlan.hostEntries(h)).thenAnswer(new Answer[util.List[CoverageEntry]] { - override def answer(invocation: InvocationOnMock): util.List[CoverageEntry] = - mapCE(h).asJava - })) - when(coveragePlan.iterator()).thenAnswer(new Answer[util.Iterator[_]] { - override def answer(invocation: InvocationOnMock): util.Iterator[_] = mapCE.values.flatten.iterator.asJava - }) - when(coveragePlan.hosts()).thenReturn(setAsJavaSet(mapCE.keySet)) - val rdd = new RiakTSRDD[Row](sc, connector, bucketName, Some(schema), None, None, filters) - val partitions = rdd.partitions - - assertEquals(3, partitions.size) - } - - @Test - def coveragePlanBasedPartitioningGreaterThanSplitCount(): Unit = { - - val connector = spy(RiakConnector(new SparkConf())) - val session = 
mock(classOf[RiakSession]) - val coveragePlan = mock(classOf[CoveragePlanResult]) - doReturn(session).when(connector).openSession(Some(Seq(any(classOf[HostAndPort])))) - when(session.execute(any(classOf[CoveragePlan]))).thenReturn(coveragePlan) - - val mapCE = createCoverageEntry(100, 3) - mapCE.keys.foreach(h => when(coveragePlan.hostEntries(h)).thenAnswer(new Answer[util.List[CoverageEntry]] { - override def answer(invocation: InvocationOnMock): util.List[CoverageEntry] = - mapCE(h).asJava - })) - when(coveragePlan.iterator()).thenAnswer(new Answer[util.Iterator[_]] { - override def answer(invocation: InvocationOnMock): util.Iterator[_] = mapCE.values.flatten.iterator.asJava - }) - - when(coveragePlan.hosts()).thenReturn(setAsJavaSet(mapCE.keySet)) - val rdd = new RiakTSRDD[Row](sc, connector, bucketName, Some(schema), None, None, filters) - val partitions = rdd.partitions - - assertEquals(10, partitions.size) - } - - private def createCoverageEntry(numOfEntries: Int, numOfHosts: Int): Map[HostAndPort, IndexedSeq[CoverageEntry]] = { - val ces = (1 to numOfEntries).map { i => - val ce = new CoverageEntry(); - ce.setCoverageContext(Array.emptyByteArray); - ce.setFieldName("time"); - ce.setLowerBound(i * 1000); - ce.setLowerBoundInclusive(true); - ce.setUpperBound(i * 2000); - ce.setUpperBoundInclusive(false); - ce.setDescription(s"${bucketName} / time >= ${i * 1000} time < ${i * 200}"); - ce.setHost("localhost"); - ce.setPort((8080 + i % numOfHosts).toInt); - ce - } - ces.groupBy(ce => HostAndPort.fromParts(ce.getHost, ce.getPort)) - } } \ No newline at end of file diff --git a/test-utils/src/main/scala/com/basho/riak/JsonTestFunctions.scala b/test-utils/src/main/scala/com/basho/riak/JsonTestFunctions.scala index 26cc2879..169c2558 100644 --- a/test-utils/src/main/scala/com/basho/riak/JsonTestFunctions.scala +++ b/test-utils/src/main/scala/com/basho/riak/JsonTestFunctions.scala @@ -25,6 +25,7 @@ import com.fasterxml.jackson.databind.{JsonSerializer, ObjectMapper, SerializerP import net.javacrumbs.jsonunit.JsonAssert import net.javacrumbs.jsonunit.core.{Configuration, Option => JsonUnitOption} import com.basho.riak.client.core.query.timeseries.{Cell => RiakCell, Row => RiakRow} +import com.basho.riak.client.core.util.HostAndPort import com.fasterxml.jackson.databind.module.SimpleModule import scala.collection.JavaConversions._ @@ -35,7 +36,8 @@ trait JsonTestFunctions extends JsonFunctions { .registerModule( new SimpleModule("RiakTs Module", new Version(1,0,0,null)) .addSerializer(classOf[RiakCell], new RiakCellSerializer) - .addSerializer(classOf[RiakRow], new RiakRowSerializer)) + .addSerializer(classOf[RiakRow], new RiakRowSerializer) + .addSerializer(classOf[HostAndPort], new HostAndPortSerializer)) protected def assertEqualsUsingJSON(jsonExpected: AnyRef, actual: AnyRef): Unit = { @@ -74,7 +76,7 @@ trait JsonTestFunctions extends JsonFunctions { } } - class RiakRowSerializer extends JsonSerializer[RiakRow] { + private class RiakRowSerializer extends JsonSerializer[RiakRow] { override def serialize(row: RiakRow, jgen: JsonGenerator, provider: SerializerProvider): Unit = { if (row == null) { jgen.writeNull() @@ -86,7 +88,7 @@ trait JsonTestFunctions extends JsonFunctions { } } - class RiakCellSerializer extends JsonSerializer[RiakCell] { + private class RiakCellSerializer extends JsonSerializer[RiakCell] { override def serialize(cell: RiakCell, jgen: JsonGenerator, provider: SerializerProvider): Unit = { if (cell == null) { jgen.writeNull() @@ -109,4 +111,9 @@ trait JsonTestFunctions 
extends JsonFunctions { } } } + + private class HostAndPortSerializer extends JsonSerializer[HostAndPort] { + override def serialize(value: HostAndPort, jgen: JsonGenerator, provider: SerializerProvider): Unit = + jgen.writeString(value.getHost + ":" + value.getPort) + } }
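
The arithmetic behind the coverage-plan-based partitioning is easiest to judge with a small standalone check. Below is a minimal sketch (not part of the patch) that exercises the two public helpers from PartitioningUtils that the partitioner combines; the App wrapper and the printed collection types are assumptions made for illustration.

import com.basho.riak.spark.rdd.partitioner.PartitioningUtils._

object PartitioningArithmeticSketch extends App {
  // 64 coverage entries split into 10 partitions -> six groups of 6 followed by four groups of 7,
  // matching the (6,6,6,6,6,6,7,7,7,7) example in the PartitioningUtils comment.
  val entries = (1 to 64).toList
  println(splitListEvenly(entries, 10).map(_.size).toList)   // List(6, 6, 6, 6, 6, 6, 7, 7, 7, 7)

  // 10 partitions spread over 3 hosts -> per-host partition counts that sum to 10,
  // with the remainder going to the first hosts, e.g. (4, 3, 3).
  println(distributeEvenly(10, 3))

  // Fewer entries than requested splits: base becomes 0 here, which is the case the new
  // isEmpty guard handles -- the old unconditional smaller.grouped(base) would have been
  // called with a group size of 0, which grouped does not accept.
  println(splitListEvenly(List("a"), 2).map(_.size).toList)  // List(1)
}

The last case appears to be the one exercised by checkPartitioningForIrregularData: a host that owns fewer coverage entries than the number of partitions assigned to it now simply yields fewer partitions instead of failing.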
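
The dump/trace additions are easiest to review by their output. Below is a small sketch (not part of the patch) that builds one TSQueryData the same way makeCoveragePlan in the new test builds its coverage entries and prints its dump(); the SQL text, table name, host, and port are made-up values, and the commented output assumes the coverage context is left unset.

import com.basho.riak.client.core.query.timeseries.CoverageEntry
import com.basho.riak.spark.query.TSQueryData

object DumpSketch extends App {
  val ce = new CoverageEntry()
  ce.setFieldName("time")
  ce.setHost("h1")
  ce.setPort(8087)
  ce.setLowerBound(1)
  ce.setLowerBoundInclusive(true)
  ce.setUpperBound(2)
  ce.setUpperBoundInclusive(false)
  ce.setDescription("test / time >= 1 AND time < 2")

  val qd = TSQueryData("SELECT * FROM test WHERE time >= 1 AND time < 2", Some(ce))
  println(qd.dump())
  // Expected shape:
  //   sql: {select * from test where time >= 1 and time < 2}
  //   primary-host: h1:8087
  //   coverage-entry:[1,2)@host: h1:8087
  //    description: test / time >= 1 AND time < 2
  //    context: null
}

RiakTSPartition.dump() nests the same per-entry text under its endpoint list, which is what the new trace/debug logging in RiakTSCoveragePlanBasedPartitioner.partitions() emits.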