diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala index 93f7311..dc906d6 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala @@ -8,8 +8,11 @@ package com.vesoft.nebula.algorithm.config import java.io.File import java.nio.file.Files import org.apache.log4j.Logger + import scala.collection.JavaConverters._ import com.typesafe.config.{Config, ConfigFactory} +import com.vesoft.nebula.algorithm.config.Configs.readConfig + import scala.collection.mutable /** @@ -17,22 +20,8 @@ import scala.collection.mutable */ object SparkConfigEntry { def apply(config: Config): SparkConfigEntry = { - val map = mutable.Map[String, String]() - val sparkConfig = config.getObject("spark") - for (key <- sparkConfig.unwrapped().keySet().asScala) { - val sparkKey = s"spark.${key}" - if (config.getAnyRef(sparkKey).isInstanceOf[String]) { - val sparkValue = config.getString(sparkKey) - map += sparkKey -> sparkValue - } else { - for (subKey <- config.getObject(sparkKey).unwrapped().keySet().asScala) { - val key = s"${sparkKey}.${subKey}" - val sparkValue = config.getString(key) - map += key -> sparkValue - } - } - } - SparkConfigEntry(map.toMap) + val map = readConfig(config, "spark") + SparkConfigEntry(map) } } @@ -41,22 +30,8 @@ object SparkConfigEntry { */ object AlgorithmConfigEntry { def apply(config: Config): AlgorithmConfigEntry = { - val map = mutable.Map[String, String]() - val algoConfig = config.getObject("algorithm") - for (key <- algoConfig.unwrapped().keySet().asScala) { - val algorithmKey = s"algorithm.${key}" - if (config.getAnyRef(algorithmKey).isInstanceOf[String]) { - val algorithmValue = config.getString(algorithmKey) - map += algorithmKey -> algorithmValue - } else { - for (subkey <- config.getObject(algorithmKey).unwrapped().keySet().asScala) { - val key = s"${algorithmKey}.${subkey}" - val value = config.getString(key) - map += key -> value - } - } - } - AlgorithmConfigEntry(map.toMap) + val map = readConfig(config, "algorithm") + AlgorithmConfigEntry(map) } } @@ -365,6 +340,24 @@ object Configs { } parser.parse(args, Argument()) } + + def readConfig(config: Config, name: String): Map[String, String] = { + val map = mutable.Map[String, String]() + val configObject = config.getObject(name) + for (key <- configObject.unwrapped().keySet().asScala) { + val refinedKey = s"$name.$key" + config.getAnyRef(refinedKey) match { + case stringValue: String => map += refinedKey -> stringValue + case _ => + for (subKey <- config.getObject(refinedKey).unwrapped().keySet().asScala) { + val refinedSubKey = s"$refinedKey.$subKey" + val refinedSubValue = config.getString(refinedSubKey) + map += refinedSubKey -> refinedSubValue + } + } + } + map.toMap + } } object AlgoConstants {