Cast partitionNum to Int (#91)
artemkorsakov committed Jan 12, 2024
1 parent 371ffbb commit 656d01c
Showing 4 changed files with 15 additions and 20 deletions.
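In short: partitionNum changes type from String to Int across the reader API. The String-to-Int conversion now happens exactly once, in SparkConfig.getSpark, and each DataReader.read implementation drops its local val partition = partitionNum.toInt. A minimal sketch of the resulting parse-once pattern (the fromSettings helper and Map-based lookup are simplified stand-ins for illustration; only the SparkConfig case class and the single .toInt call come from this commit):

import org.apache.spark.sql.SparkSession

// Parse the partition setting once, at the configuration boundary,
// so downstream readers receive an Int and never re-parse the String.
case class SparkConfig(spark: SparkSession, partitionNum: Int)

object SparkConfig {
  // Simplified stand-in for getSpark (the real method also builds the
  // SparkSession and validates the Spark version; see the diff below).
  def fromSettings(settings: Map[String, String], spark: SparkSession): SparkConfig = {
    val raw = settings.getOrElse("spark.app.partitionNum", "0")
    SparkConfig(spark, raw.toInt) // the single String -> Int conversion
  }
}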
@@ -76,7 +76,7 @@ object Main {
    */
   private[this] def createDataSource(spark: SparkSession,
                                      configs: Configs,
-                                     partitionNum: String): DataFrame = {
+                                     partitionNum: Int): DataFrame = {
     val dataSource = DataReader.make(configs)
     dataSource.read(spark, configs, partitionNum)
   }
@@ -7,7 +7,7 @@ package com.vesoft.nebula.algorithm.config

 import org.apache.spark.sql.SparkSession

-case class SparkConfig(spark: SparkSession, partitionNum: String)
+case class SparkConfig(spark: SparkSession, partitionNum: Int)

 object SparkConfig {

@@ -27,7 +27,7 @@ object SparkConfig {
     partitionNum = sparkConfigs.getOrElse("spark.app.partitionNum", "0")
     val spark = session.getOrCreate()
     validate(spark.version, "2.4.*")
-    SparkConfig(spark, partitionNum)
+    SparkConfig(spark, partitionNum.toInt)
   }

   def validate(sparkVersion: String, supportedVersions: String*): Unit = {
@@ -14,7 +14,7 @@ import scala.collection.mutable.ListBuffer

 abstract class DataReader {
   val tpe: ReaderType
-  def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame
+  def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame
 }
 object DataReader {
   def make(configs: Configs): DataReader = {
@@ -32,12 +32,11 @@ object DataReader {

 class NebulaReader extends DataReader {
   override val tpe: ReaderType = ReaderType.nebula
-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val metaAddress = configs.nebulaConfig.readConfigEntry.address
     val space = configs.nebulaConfig.readConfigEntry.space
     val labels = configs.nebulaConfig.readConfigEntry.labels
     val weights = configs.nebulaConfig.readConfigEntry.weightCols
-    val partition = partitionNum.toInt

     val config =
       NebulaConnectionConfig
@@ -60,7 +59,7 @@ class NebulaReader extends DataReader {
           .withLabel(labels(i))
           .withNoColumn(noColumn)
           .withReturnCols(returnCols.toList)
-          .withPartitionNum(partition)
+          .withPartitionNum(partitionNum)
           .build()
       if (dataset == null) {
         dataset = spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF()
@@ -85,13 +84,12 @@ final class NebulaNgqlReader extends NebulaReader {

   override val tpe: ReaderType = ReaderType.nebulaNgql

-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val metaAddress = configs.nebulaConfig.readConfigEntry.address
     val graphAddress = configs.nebulaConfig.readConfigEntry.graphAddress
     val space = configs.nebulaConfig.readConfigEntry.space
     val labels = configs.nebulaConfig.readConfigEntry.labels
     val weights = configs.nebulaConfig.readConfigEntry.weightCols
-    val partition = partitionNum.toInt
     val ngql = configs.nebulaConfig.readConfigEntry.ngql

     val config =
@@ -112,7 +110,7 @@ final class NebulaNgqlReader extends NebulaReader {
           .builder()
           .withSpace(space)
           .withLabel(labels(i))
-          .withPartitionNum(partition)
+          .withPartitionNum(partitionNum)
           .withNgql(ngql)
           .build()
       if (dataset == null) {
@@ -137,13 +135,11 @@ final class NebulaNgqlReader extends NebulaReader {

 final class CsvReader extends DataReader {
   override val tpe: ReaderType = ReaderType.csv
-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val delimiter = configs.localConfigEntry.delimiter
     val header = configs.localConfigEntry.header
     val localPath = configs.localConfigEntry.filePath

-    val partition = partitionNum.toInt
-
     val data =
       spark.read
         .option("header", header)
@@ -157,18 +153,17 @@ final class CsvReader extends DataReader {
     } else {
       data.select(src, dst)
     }
-    if (partition != 0) {
-      data.repartition(partition)
+    if (partitionNum != 0) {
+      data.repartition(partitionNum)
     }
     data
   }
 }
 final class JsonReader extends DataReader {
   override val tpe: ReaderType = ReaderType.json
-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val localPath = configs.localConfigEntry.filePath
     val data = spark.read.json(localPath)
-    val partition = partitionNum.toInt

     val weight = configs.localConfigEntry.weight
     val src = configs.localConfigEntry.srcId
@@ -178,8 +173,8 @@ final class JsonReader extends DataReader {
     } else {
       data.select(src, dst)
     }
-    if (partition != 0) {
-      data.repartition(partition)
+    if (partitionNum != 0) {
+      data.repartition(partitionNum)
     }
     data
   }
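A side observation on CsvReader and JsonReader, unchanged by this commit: Spark DataFrames are immutable, so data.repartition(partitionNum) returns a new DataFrame rather than modifying data in place, and as written its result is discarded before data is returned. If the repartitioning is meant to take effect, the intended shape is presumably something like this sketch (maybeRepartition is a hypothetical helper, not code from the repository):

import org.apache.spark.sql.DataFrame

// Sketch only: keep repartition's return value instead of discarding it;
// a partitionNum of 0 means "leave the partitioning as-is".
def maybeRepartition(data: DataFrame, partitionNum: Int): DataFrame =
  if (partitionNum != 0) data.repartition(partitionNum) else data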
@@ -46,7 +46,7 @@ class ConfigSuite {
     assert(sparkConfig.map.size == 3)

     val spark = SparkConfig.getSpark(configs)
-    assert(spark.partitionNum.toInt == 100)
+    assert(spark.partitionNum == 100)
   }

   @Test
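Since parsing is now centralized, SparkConfig.getSpark is also the one place a malformed setting surfaces: String.toInt throws NumberFormatException for values such as "abc". The commit keeps the plain .toInt; a defensive variant, shown here only as a sketch and not part of this change, could fall back to the "0" default instead:

import scala.util.Try

// Sketch only: fall back to 0 (the documented default) on malformed
// input instead of throwing NumberFormatException.
def parsePartitionNum(raw: String): Int = Try(raw.toInt).getOrElse(0)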