From 656d01c5e3b59a5eba49c193c083720d89ea22bb Mon Sep 17 00:00:00 2001 From: Artem Korsakov Date: Fri, 12 Jan 2024 13:29:32 +0300 Subject: [PATCH] Cast partitionNum to Int (#91) --- .../com/vesoft/nebula/algorithm/Main.scala | 2 +- .../nebula/algorithm/config/SparkConfig.scala | 4 +-- .../nebula/algorithm/reader/DataReader.scala | 27 ++++++++----------- .../nebula/algorithm/config/ConfigSuite.scala | 2 +- 4 files changed, 15 insertions(+), 20 deletions(-) diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala index 13ec599..0b2fc54 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala @@ -76,7 +76,7 @@ object Main { */ private[this] def createDataSource(spark: SparkSession, configs: Configs, - partitionNum: String): DataFrame = { + partitionNum: Int): DataFrame = { val dataSource = DataReader.make(configs) dataSource.read(spark, configs, partitionNum) } diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/SparkConfig.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/SparkConfig.scala index 8a9d60d..5a1908e 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/SparkConfig.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/SparkConfig.scala @@ -7,7 +7,7 @@ package com.vesoft.nebula.algorithm.config import org.apache.spark.sql.SparkSession -case class SparkConfig(spark: SparkSession, partitionNum: String) +case class SparkConfig(spark: SparkSession, partitionNum: Int) object SparkConfig { @@ -27,7 +27,7 @@ object SparkConfig { partitionNum = sparkConfigs.getOrElse("spark.app.partitionNum", "0") val spark = session.getOrCreate() validate(spark.version, "2.4.*") - SparkConfig(spark, partitionNum) + SparkConfig(spark, partitionNum.toInt) } def validate(sparkVersion: String, supportedVersions: String*): Unit = { diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala index b1cbe28..e11d868 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala @@ -14,7 +14,7 @@ import scala.collection.mutable.ListBuffer abstract class DataReader { val tpe: ReaderType - def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame + def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame } object DataReader { def make(configs: Configs): DataReader = { @@ -32,12 +32,11 @@ object DataReader { class NebulaReader extends DataReader { override val tpe: ReaderType = ReaderType.nebula - override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = { + override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = { val metaAddress = configs.nebulaConfig.readConfigEntry.address val space = configs.nebulaConfig.readConfigEntry.space val labels = configs.nebulaConfig.readConfigEntry.labels val weights = configs.nebulaConfig.readConfigEntry.weightCols - val partition = partitionNum.toInt val config = NebulaConnectionConfig @@ -60,7 +59,7 @@ class NebulaReader extends DataReader { .withLabel(labels(i)) .withNoColumn(noColumn) .withReturnCols(returnCols.toList) - .withPartitionNum(partition) + .withPartitionNum(partitionNum) .build() if (dataset == null) { dataset = spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF() @@ -85,13 +84,12 @@ final class NebulaNgqlReader extends NebulaReader { override val tpe: ReaderType = ReaderType.nebulaNgql - override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = { + override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = { val metaAddress = configs.nebulaConfig.readConfigEntry.address val graphAddress = configs.nebulaConfig.readConfigEntry.graphAddress val space = configs.nebulaConfig.readConfigEntry.space val labels = configs.nebulaConfig.readConfigEntry.labels val weights = configs.nebulaConfig.readConfigEntry.weightCols - val partition = partitionNum.toInt val ngql = configs.nebulaConfig.readConfigEntry.ngql val config = @@ -112,7 +110,7 @@ final class NebulaNgqlReader extends NebulaReader { .builder() .withSpace(space) .withLabel(labels(i)) - .withPartitionNum(partition) + .withPartitionNum(partitionNum) .withNgql(ngql) .build() if (dataset == null) { @@ -137,13 +135,11 @@ final class NebulaNgqlReader extends NebulaReader { final class CsvReader extends DataReader { override val tpe: ReaderType = ReaderType.csv - override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = { + override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = { val delimiter = configs.localConfigEntry.delimiter val header = configs.localConfigEntry.header val localPath = configs.localConfigEntry.filePath - val partition = partitionNum.toInt - val data = spark.read .option("header", header) @@ -157,18 +153,17 @@ final class CsvReader extends DataReader { } else { data.select(src, dst) } - if (partition != 0) { - data.repartition(partition) + if (partitionNum != 0) { + data.repartition(partitionNum) } data } } final class JsonReader extends DataReader { override val tpe: ReaderType = ReaderType.json - override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = { + override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = { val localPath = configs.localConfigEntry.filePath val data = spark.read.json(localPath) - val partition = partitionNum.toInt val weight = configs.localConfigEntry.weight val src = configs.localConfigEntry.srcId @@ -178,8 +173,8 @@ final class JsonReader extends DataReader { } else { data.select(src, dst) } - if (partition != 0) { - data.repartition(partition) + if (partitionNum != 0) { + data.repartition(partitionNum) } data } diff --git a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/config/ConfigSuite.scala b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/config/ConfigSuite.scala index 0ab82d1..8dfabc3 100644 --- a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/config/ConfigSuite.scala +++ b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/config/ConfigSuite.scala @@ -46,7 +46,7 @@ class ConfigSuite { assert(sparkConfig.map.size == 3) val spark = SparkConfig.getSpark(configs) - assert(spark.partitionNum.toInt == 100) + assert(spark.partitionNum == 100) } @Test