diff --git a/angelml/src/main/scala/org/apache/spark/angel/examples/graph/LouvainExample.scala b/angelml/src/main/scala/org/apache/spark/angel/examples/graph/LouvainExample.scala index b4dadb5..d037ad5 100644 --- a/angelml/src/main/scala/org/apache/spark/angel/examples/graph/LouvainExample.scala +++ b/angelml/src/main/scala/org/apache/spark/angel/examples/graph/LouvainExample.scala @@ -42,6 +42,9 @@ object LouvainExample { val eps = params.getOrElse("eps", "0.0").toDouble val bufferSize = params.getOrElse("bufferSize", "1000000").toInt val isWeighted = params.getOrElse("isWeighted", "false").toBoolean + val srcIndex = params.getOrElse("srcIndex", "0").toInt + val dstIndex = params.getOrElse("dstIndex", "1").toInt + val weightIndex = params.getOrElse("weightIndex", "2").toInt val psPartitionNum = params.getOrElse("psPartitionNum", sc.getConf.get("spark.ps.instances", "10")).toInt @@ -49,6 +52,12 @@ object LouvainExample { val cpDir = params.get("cpDir").filter(_.nonEmpty).orElse(GraphIO.defaultCheckpointDir) .getOrElse(throw new Exception("checkpoint dir not provided")) sc.setCheckpointDir(cpDir) + + val sep = params.getOrElse("sep", "space") match { + case "space" => " " + case "comma" => "," + case "tab" => "\t" + } val louvain = new Louvain() .setPartitionNum(partitionNum) @@ -61,10 +70,11 @@ object LouvainExample { .setBufferSize(bufferSize) .setIsWeighted(isWeighted) .setPSPartitionNum(psPartitionNum) - .setSrcNodeIdCol("src") - .setDstNodeIdCol("dst") + - val df = GraphIO.load(input, isWeighted = isWeighted) + val df = GraphIO.load(input, isWeighted = isWeighted, + srcIndex = srcIndex, dstIndex = dstIndex, + weightIndex = weightIndex, sep = sep) val mapping = louvain.transform(df) GraphIO.save(mapping, output) stop()