diff --git a/java/src/main/scala/com/tencent/angel/pytorch/io/IOFunctions.scala b/java/src/main/scala/com/tencent/angel/pytorch/io/IOFunctions.scala
index 54b8efe..209724a 100644
--- a/java/src/main/scala/com/tencent/angel/pytorch/io/IOFunctions.scala
+++ b/java/src/main/scala/com/tencent/angel/pytorch/io/IOFunctions.scala
@@ -129,6 +129,83 @@ object IOFunctions {
     df
   }
 
+  /**
+   * Loads a labelled edge list from header-less delimited text.
+   *
+   * Columns are read positionally: `src, dst, type, label` when `isTyped` is true,
+   * otherwise `src, dst, label`.
+   *
+   * NOTE(review): `srcIndex`, `dstIndex`, `typeIndex` and `labelIndex` are accepted but
+   * never read by this method; the column order is fixed by the schema built below.
+   *
+   * @param input      path passed to Spark's CSV reader
+   * @param isTyped    whether rows carry an integer `type` column before the label
+   * @param srcIndex   currently unused
+   * @param dstIndex   currently unused
+   * @param typeIndex  currently unused
+   * @param labelIndex currently unused
+   * @param sep        field separator
+   * @return the persisted DataFrame
+   * @throws AngelException if no row has a non-null first column
+   */
+  def loadEdgeWithLabel(input: String, isTyped: Boolean,
+                        srcIndex: Int = 0, dstIndex: Int = 1, typeIndex: Int = 2, labelIndex: Int = 3,
+                        sep: String = " "): DataFrame = {
+    val endpoints = Seq(
+      StructField("src", LongType, nullable = false),
+      StructField("dst", LongType, nullable = false))
+    val label = StructField("label", FloatType, nullable = false)
+    val fields =
+      if (isTyped) endpoints :+ StructField("type", IntegerType, nullable = false) :+ label
+      else endpoints :+ label
+    readCsvChecked(input, sep, StructType(fields), "The edge format is incorrect, please check!!!")
+  }
+
+  /**
+   * Loads a node-to-type table from header-less delimited text.
+   *
+   * Columns are read positionally as `node, type`.
+   *
+   * NOTE(review): `nodeIndex` and `typeIndex` are accepted but never read by this method.
+   *
+   * @param input     path passed to Spark's CSV reader
+   * @param nodeIndex currently unused
+   * @param typeIndex currently unused
+   * @param sep       field separator
+   * @return the persisted DataFrame
+   * @throws AngelException if no row has a non-null first column
+   */
+  def loadNodeType(input: String, nodeIndex: Int = 0,
+                   typeIndex: Int = 1, sep: String = " "): DataFrame = {
+    val schema = StructType(Seq(
+      StructField("node", LongType, nullable = false),
+      StructField("type", IntegerType, nullable = false)
+    ))
+    readCsvChecked(input, sep, schema, "The type format is incorrect, please check!!!")
+  }
+
+  /**
+   * Reads `input` as header-less delimited text with `schema`, persists the result and
+   * checks that at least one row has a non-null first column.
+   *
+   * @throws AngelException carrying `errorMessage` when that check fails
+   */
+  private def readCsvChecked(input: String, sep: String, schema: StructType,
+                             errorMessage: String): DataFrame = {
+    val df = SparkSession.builder().getOrCreate().read
+      .option("sep", sep)
+      .option("header", "false")
+      .schema(schema)
+      .csv(input)
+    df.persist()
+    if (df.rdd.filter(row => !row.isNullAt(0)).count() == 0) {
+      // Drop the cached data before failing so a rejected input does not stay pinned.
+      df.unpersist()
+      throw new AngelException(errorMessage)
+    }
+    df
+  }
+
   def parseSep(sep: String): String = {
     sep match {
       case "space" => " "