Pregel and Shortest Path Algorithm in GraphX

If you look closely at Graph.scala, you will find the cache function (and persist, which takes an explicit storage level). GraphX does not call them automatically, so developers have to do it themselves. In an iterative algorithm, each iteration should cache its new results and unpersist the data that is no longer needed, otherwise the computation slows down. Doing this by hand is inconvenient and error-prone, because a GraphX graph consists of separate vertex and edge RDDs, each of which has to be cached and released. To solve this problem and make iteration easier, GraphX provides a Pregel-like API [1].

The Pregel API is provided in Pregel.scala. It is not identical to the original Pregel API; it is Pregel-like. It implements a bulk-synchronous message-passing API in which the message-sending computation can read the attributes of both endpoints of an edge, and messages can only be sent along the edges of the graph.

In the apply method of object Pregel, VD is the vertex data type, ED is the edge data type, and A is the Pregel message type. Every vertex in the graph receives the initial message initialMsg in the first iteration. The loop runs at most maxIterations times, and stops earlier as soon as no messages remain. activeDirection selects the edges, incident to a vertex that received a message in the previous round, on which sendMsg runs. For example, with EdgeDirection.Either (the default), sendMsg runs on every edge where either endpoint received a message in the previous round. The commutative and associative function mergeMsg merges two incoming messages for the same vertex into a single message. The vertex program vprog is similar to onMessage in WebSocket: it runs in parallel on each vertex that receives inbound messages and computes the vertex’s new value, while vertices that receive no message in a round keep their old value.

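object Pregel extends Logging {
  def apply[VD: ClassTag, ED: ClassTag, A: ClassTag]
     (graph: Graph[VD, ED],
      initialMsg: A,
      maxIterations: Int = Int.MaxValue,
      activeDirection: EdgeDirection = EdgeDirection.Either)
     (vprog: (VertexId, VD, A) => VD,
      sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
      mergeMsg: (A, A) => A)
    : Graph[VD, ED] = {
    // Run vprog on every vertex with initialMsg, then compute the first round of messages
    var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
    var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
    var activeMessages = messages.count()
    // Loop
    var prevG: Graph[VD, ED] = null
    var i = 0
    while (activeMessages > 0 && i < maxIterations) {
      // Receive the messages. Vertices that didn't get any messages do not appear in newVerts.
      val newVerts = g.vertices.innerJoin(messages)(vprog).cache()
      // Update the graph with the new vertices.
      prevG = g
      g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }
      g.cache()
      // Send new messages, only on edges selected by activeDirection around newVerts.
      val oldMessages = messages
      messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache()
      // count() materializes `messages`, `newVerts`, and the vertices of `g`. This hides
      // oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
      // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
      activeMessages = messages.count()
      // Unpersist the RDDs hidden by newly-materialized RDDs
      oldMessages.unpersist(blocking = false)
      newVerts.unpersist(blocking = false)
      prevG.unpersistVertices(blocking = false)
      prevG.edges.unpersist(blocking = false)
      // count the iteration
      i += 1
    }
    // Return the resulting graph
    g
  } // end of apply
}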
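To make the control flow of this loop concrete, here is a single-machine sketch of the same superstep structure in plain Scala, without Spark. MiniPregel, Edge and Triplet are hypothetical names; the sketch only models EdgeDirection.Either activation and ignores partitioning, caching and fault tolerance. The demo uses the maximum-value propagation example from the Pregel paper [1]: every vertex ends up with the largest value in its connected component.

```scala
// Single-machine sketch of Pregel's superstep loop (hypothetical MiniPregel, not GraphX).
// Simplifications: vertex ids are Long, only EdgeDirection.Either activation is modeled,
// and there is no partitioning, caching or fault tolerance.
object MiniPregel {
  final case class Edge[E](srcId: Long, dstId: Long, attr: E)
  final case class Triplet[V, E](srcId: Long, srcAttr: V, dstId: Long, dstAttr: V, attr: E)

  def run[V, E, A](vertices: Map[Long, V], edges: Seq[Edge[E]], initialMsg: A,
                   maxIterations: Int = Int.MaxValue)
                  (vprog: (Long, V, A) => V,
                   sendMsg: Triplet[V, E] => Iterator[(Long, A)],
                   mergeMsg: (A, A) => A): Map[Long, V] = {
    // Every vertex first runs vprog on initialMsg
    var g: Map[Long, V] = vertices.map { case (id, v) => id -> vprog(id, v, initialMsg) }

    // Run sendMsg on edges touching an active vertex and merge messages per target vertex
    def messages(active: Long => Boolean): Map[Long, A] =
      edges.iterator
        .filter(e => active(e.srcId) || active(e.dstId))
        .flatMap(e => sendMsg(Triplet(e.srcId, g(e.srcId), e.dstId, g(e.dstId), e.attr)))
        .foldLeft(Map.empty[Long, A]) { case (acc, (id, msg)) =>
          acc.updated(id, acc.get(id).fold(msg)(old => mergeMsg(old, msg)))
        }

    var msgs = messages(_ => true) // the first round considers every edge
    var i = 0
    while (msgs.nonEmpty && i < maxIterations) {
      // Only vertices that received a message run vprog
      val newVerts = msgs.map { case (id, m) => id -> vprog(id, g(id), m) }
      g = g ++ newVerts
      // Only edges next to a vertex that just received a message may send again
      msgs = messages(newVerts.contains)
      i += 1
    }
    g
  }
}

// Maximum-value propagation on a chain 1 - 2 - 3 - 4 with edges in both directions
object MaxValueDemo {
  import MiniPregel._
  val values = Map(1L -> 3, 2L -> 6, 3L -> 2, 4L -> 1)
  val edges = Seq(Edge(1L, 2L, ()), Edge(2L, 1L, ()), Edge(2L, 3L, ()),
                  Edge(3L, 2L, ()), Edge(3L, 4L, ()), Edge(4L, 3L, ()))
  val result = run(values, edges, Int.MinValue)(
    (_, v, m) => math.max(v, m),
    t => if (t.srcAttr > t.dstAttr) Iterator((t.dstId, t.srcAttr)) else Iterator.empty,
    (a, b) => math.max(a, b))
}
```

Passing maxIterations = 1 instead stops after one superstep, at which point vertex 4 still holds 2; the pending message is simply dropped, as in the GraphX loop.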

The Pregel API makes it easy to compute PageRank and shortest paths. This article introduces shortest paths on the web-Google graph first; PageRank will be covered in a later article.

Computer scientists have proposed several algorithms for the shortest path problem [2], such as Dijkstra’s algorithm [3] and the Bellman–Ford algorithm [4]. Their implementations differ, but they share the same core, the relaxation operation [5]: if the distance from a to c, dis[a][c], is greater than the distance from a to b, dis[a][b], plus the distance from b to c, dis[b][c], then dis[a][c] is updated to dis[a][b] + dis[b][c].
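The relaxation step, and the Bellman–Ford loop built from it, can be sketched in plain Scala. Here the source a is fixed, so dis[a][x] is stored as dist(x), and dis[b][c] is the weight w of the edge (b, c). relax and bellmanFord are illustrative names, not functions from GraphX or any other library.

```scala
// Sketch of the relaxation step with a fixed source; names are illustrative only.
object Relaxation {
  type Dist = Map[String, Double]

  // If dist(b) + w(b, c) < dist(c), shorten dist(c); otherwise return dist unchanged
  def relax(dist: Dist, b: String, c: String, w: Double): Dist = {
    val viaB = dist.getOrElse(b, Double.PositiveInfinity) + w
    if (viaB < dist.getOrElse(c, Double.PositiveInfinity)) dist.updated(c, viaB) else dist
  }

  // Bellman–Ford: relax every edge |V| - 1 times (assumes no negative-weight cycles)
  def bellmanFord(vertices: Seq[String], edges: Seq[(String, String, Double)], source: String): Dist = {
    var dist: Dist = vertices.map(v => v -> (if (v == source) 0.0 else Double.PositiveInfinity)).toMap
    for (_ <- 1 until vertices.size; (b, c, w) <- edges) dist = relax(dist, b, c, w)
    dist
  }
}
```

Dijkstra’s algorithm applies the same relax step but picks the edges in order of increasing distance, so each edge is relaxed only once.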

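Now for the example. Import the graph, define the source vertex, and initialize the distances that will be updated in each iteration. Although this is often described as Dijkstra’s algorithm, the Pregel version is really a parallel, Bellman–Ford-style relaxation: in every superstep, all vertices whose distance just changed relax their outgoing edges at once, instead of settling the closest unvisited vertex first.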

scala> import org.apache.spark._
scala> import org.apache.spark.graphx._
scala> import org.apache.spark.rdd.RDD
scala> val graph = GraphLoader.edgeListFile(sc, "hdfs://")
scala> val sourceId: VertexId = 0
scala> val g = graph.mapVertices( (id, _) =>
     |   if (id == sourceId) 0.0
     |   else Double.PositiveInfinity
     | )

Then call the Pregel API:

scala> val sssp = g.pregel(Double.PositiveInfinity)(
     |   (id, dist, newDist) => math.min(dist, newDist),
     |   triplet => {
     |     if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
     |       Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
     |     }
     |     else {
     |       Iterator.empty
     |     }
     |   },
     |   (a, b) => math.min(a, b)
     | )

View the result:

scala> sssp.vertices.take(10).mkString("\n")
res0: String =
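The three functions passed to g.pregel above can be traced on a tiny graph without Spark. The sketch below is a stand-in rather than GraphX code: SsspSketch is a hypothetical name, edges are plain (src, dst) pairs, and, as with GraphLoader.edgeListFile, every edge attribute is 1, so the resulting distances are hop counts. It also counts supersteps; with unit weights, that number equals the largest finite hop distance from the source.

```scala
// Plain-Scala sketch (no Spark) of the sssp call above. SsspSketch is a hypothetical name.
// Assumption: as with GraphLoader.edgeListFile, every edge attribute is 1.
object SsspSketch {
  // The three functions passed to g.pregel, with explicit types
  val vprog: (Long, Double, Double) => Double = (_, dist, newDist) => math.min(dist, newDist)

  def sendMsg(srcId: Long, srcAttr: Double, dstId: Long, dstAttr: Double, attr: Int): Iterator[(Long, Double)] =
    if (srcAttr + attr < dstAttr) Iterator((dstId, srcAttr + attr)) else Iterator.empty

  val mergeMsg: (Double, Double) => Double = (a, b) => math.min(a, b)

  // One round of sendMsg over all edges, merging messages that target the same vertex
  private def messages(edges: Seq[(Long, Long)], dist: Map[Long, Double]): Map[Long, Double] =
    edges
      .flatMap { case (s, d) => sendMsg(s, dist(s), d, dist(d), 1) }
      .groupBy(_._1)
      .map { case (id, ms) => id -> ms.map(_._2).reduce(mergeMsg) }

  // Returns the final distances and the number of supersteps that delivered messages.
  // (vprog with initialMsg = Infinity changes nothing, so that initial step is skipped.)
  def run(edges: Seq[(Long, Long)], sourceId: Long): (Map[Long, Double], Int) = {
    val ids = edges.flatMap { case (s, d) => Seq(s, d) }.distinct
    var dist = ids.map(id => id -> (if (id == sourceId) 0.0 else Double.PositiveInfinity)).toMap
    var msgs = messages(edges, dist)
    var supersteps = 0
    while (msgs.nonEmpty) {
      dist = dist ++ msgs.map { case (id, m) => id -> vprog(id, dist(id), m) }
      msgs = messages(edges, dist)
      supersteps += 1
    }
    (dist, supersteps)
  }
}
```

For example, SsspSketch.run(Seq(0L -> 1L, 1L -> 2L, 0L -> 2L, 2L -> 3L, 4L -> 0L), 0L) leaves vertex 4 at Infinity, because no edge leads from the source to it.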

GraphX also provides a shortest-path algorithm in lib/ShortestPaths.scala. Instead of computing single source shortest paths (SSSP), it computes, for every vertex, the shortest distance (in hops, ignoring edge weights) to each vertex in a given set of landmarks.

object ShortestPaths {
  def run[VD, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
    // Each landmark starts at distance 0 from itself; other vertices start with an empty map
    val spGraph = graph.mapVertices { (vid, attr) =>
      if (landmarks.contains(vid)) makeMap(vid -> 0) else makeMap()
    }

    val initialMessage = makeMap()

    def vertexProgram(id: VertexId, attr: SPMap, msg: SPMap): SPMap = {
      addMaps(attr, msg)
    }

    // Send the destination's distances, plus one hop, back to the source if they improve it
    def sendMessage(edge: EdgeTriplet[SPMap, _]): Iterator[(VertexId, SPMap)] = {
      val newAttr = incrementMap(edge.dstAttr)
      if (edge.srcAttr != addMaps(newAttr, edge.srcAttr)) Iterator((edge.srcId, newAttr))
      else Iterator.empty
    }

    Pregel(spGraph, initialMessage)(vertexProgram, sendMessage, addMaps)
  }
}

landmarks is the list of landmark vertex IDs to which the shortest paths are computed. makeMap is a helper that builds an SPMap from (VertexId, distance) pairs:

private def makeMap(x: (VertexId, Int)*) = Map(x: _*)

SPMap is a map type [6], type SPMap = Map[VertexId, Int], that maps the vertex ID of each landmark to the distance to that landmark.

addMaps merges two SPMaps, keeping the smaller distance for each landmark. vertexProgram simply calls addMaps, and addMaps is also passed to Pregel as mergeMsg.

private def addMaps(spmap1: SPMap, spmap2: SPMap): SPMap =
  (spmap1.keySet ++ spmap2.keySet).map {
    k => k -> math.min(spmap1.getOrElse(k, Int.MaxValue), spmap2.getOrElse(k, Int.MaxValue))
  }.toMap

incrementMap adds one to every distance, since every edge is treated as having weight 1. The direction of propagation may look odd: sendMessage reads the destination’s map and sends it to the source of the edge. This is deliberate: on a directed graph, each vertex learns the length of a path from itself to each landmark along outgoing edges, not the distance from the landmark to the vertex.

private def incrementMap(spmap: SPMap): SPMap = spmap.map { case (v, d) => v -> (d + 1) }
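Putting these helpers together, the landmark computation can be simulated on a small directed graph in plain Scala. This is a hypothetical stand-in for lib/ShortestPaths, not its actual code: LandmarkSketch and run are illustrative names, VertexId is replaced by Long, and the Pregel loop is written out by hand. The usage in the comment shows the direction effect described above: vertex 1 learns its distance to landmark 4 along 1 -> 2 -> 3 -> 4, but no other vertex learns a distance to landmark 1.

```scala
// Plain-Scala sketch of lib/ShortestPaths (no Spark); VertexId is replaced by Long.
// LandmarkSketch is a hypothetical name; the three helpers follow the definitions above.
object LandmarkSketch {
  type SPMap = Map[Long, Int]

  def makeMap(x: (Long, Int)*): SPMap = Map(x: _*)

  def incrementMap(spmap: SPMap): SPMap = spmap.map { case (v, d) => v -> (d + 1) }

  def addMaps(spmap1: SPMap, spmap2: SPMap): SPMap =
    (spmap1.keySet ++ spmap2.keySet).map { k =>
      k -> math.min(spmap1.getOrElse(k, Int.MaxValue), spmap2.getOrElse(k, Int.MaxValue))
    }.toMap

  // edges are (src, dst) pairs; returns each vertex's map from landmark to hop distance
  def run(edges: Seq[(Long, Long)], landmarks: Seq[Long]): Map[Long, SPMap] = {
    val ids = edges.flatMap { case (s, d) => Seq(s, d) }.distinct
    var attrs: Map[Long, SPMap] =
      ids.map(id => id -> (if (landmarks.contains(id)) makeMap(id -> 0) else makeMap())).toMap

    // sendMessage: the destination's distances, plus one hop, flow back to the source
    def messages(): Map[Long, SPMap] =
      edges.flatMap { case (src, dst) =>
        val newAttr = incrementMap(attrs(dst))
        if (attrs(src) != addMaps(newAttr, attrs(src))) Iterator(src -> newAttr) else Iterator.empty
      }.groupBy(_._1).map { case (id, ms) => id -> ms.map(_._2).reduce((a, b) => addMaps(a, b)) }

    var msgs = messages()
    while (msgs.nonEmpty) {
      // vertexProgram is addMaps(attr, msg)
      attrs = attrs ++ msgs.map { case (id, m) => id -> addMaps(attrs(id), m) }
      msgs = messages()
    }
    attrs
  }
}

// Usage: LandmarkSketch.run(Seq(1L -> 2L, 2L -> 3L, 3L -> 4L), Seq(1L, 4L))
```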

Example: Dijkstra Algorithm in GraphX

Thanks to Stephen for his question in the comments. With a slight modification of the code above, we can easily compute single source shortest paths (SSSP) on a weighted graph. First, import the graph as before:

import org.apache.spark._
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

val graph = GraphLoader.edgeListFile(sc, "hdfs://")

As I don’t have graph data with edge weights on hand, I assign random integer weights in the range [0, 100) to the edges. Each vertex attribute is a two-element array holding the distance from sourceId and the VertexId of the previous node on the shortest path. Here is the initialization of the vertices and the edge weights:

import scala.util.Random

val sourceId: VertexId = 0
// Vertex attribute: Array(distance from sourceId, id of the previous node on the path)
val g = graph.mapVertices( (id, _) =>
  if (id == sourceId) Array(0.0, id)
  else Array(Double.PositiveInfinity, id)
).mapEdges( e => Random.nextInt(100) ) // random integer weight in [0, 100)

Have a look at this graph:

scala> g.vertices.take(10)
res0: Array[(org.apache.spark.graphx.VertexId, Array[Double])] = Array((354796,Array(Infinity, 354796.0)), (672890,Array(Infinity, 672890.0)), (129434,Array(Infinity, 129434.0)), (194402,Array(Infinity, 194402.0)), (199516,Array(Infinity, 199516.0)), (332918,Array(Infinity, 332918.0)), (170792,Array(Infinity, 170792.0)), (386896,Array(Infinity, 386896.0)), (691634,Array(Infinity, 691634.0)), (291526,Array(Infinity, 291526.0)))
scala> g.edges.take(10)
res1: Array[org.apache.spark.graphx.Edge[Int]] = Array(Edge(0,11342,47), Edge(0,824020,53), Edge(0,867923,90), Edge(0,891835,58), Edge(1,53051,7), Edge(1,203402,50), Edge(1,223236,80), Edge(1,276233,16), Edge(1,552600,76), Edge(1,569212,47))

Modify sssp slightly so that it also records the previous node on the path whenever a distance is updated:

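val sssp = g.pregel(Array(Double.PositiveInfinity, -1))(
  // vprog: keep whichever (distance, previous node) pair has the smaller distance
  (id, dist, newDist) => {
    if (dist(0) < newDist(0)) dist
    else newDist
  },
  // sendMsg: relax the edge src -> dst and record src as dst's previous node
  triplet => {
    if (triplet.srcAttr(0) + triplet.attr < triplet.dstAttr(0)) {
      Iterator((triplet.dstId, Array(triplet.srcAttr(0) + triplet.attr, triplet.srcId)))
    }
    else {
      Iterator.empty
    }
  },
  // mergeMsg: of two incoming candidates, keep the one with the smaller distance
  (a, b) => {
    if (a(0) < b(0)) a
    else b
  }
)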

Format and print the answer:

val ans: RDD[String] = sssp.vertices.map(vertex =>
    "Vertex " + vertex._1 + ": distance is " + vertex._2(0) + ", previous node is Vertex " + vertex._2(1).toInt)

scala> ans.take(10).mkString("\n")
res2: String =
Vertex 354796: distance is 191.0, previous node is Vertex 283953
Vertex 672890: distance is 277.0, previous node is Vertex 781510
Vertex 129434: distance is 292.0, previous node is Vertex 119943
Vertex 194402: distance is 461.0, previous node is Vertex 446259
Vertex 199516: distance is 199.0, previous node is Vertex 458892
Vertex 332918: distance is 337.0, previous node is Vertex 89384
Vertex 170792: distance is 321.0, previous node is Vertex 138757
Vertex 386896: distance is 210.0, previous node is Vertex 580484
Vertex 691634: distance is 169.0, previous node is Vertex 400059
Vertex 291526: distance is Infinity, previous node is Vertex -1
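To check this logic without a cluster, the same three functions can be run on a small weighted graph in plain Scala. As Stephen suggests in the comments below, this sketch replaces Array(distance, previous node) with a case class, so the previous node is no longer stored as a Double. WeightedSsspSketch and DistPrev are hypothetical names, the edges are (src, dst, weight) triples with fixed rather than random weights, and the hand-written loop stands in for GraphX’s Pregel.

```scala
// Plain-Scala sketch of the weighted sssp above, with a case class in place of
// Array(distance, previous node). WeightedSsspSketch and DistPrev are hypothetical names.
object WeightedSsspSketch {
  final case class DistPrev(dist: Double, prev: Long)

  // vprog: keep the candidate with the smaller distance
  def vprog(id: Long, cur: DistPrev, msg: DistPrev): DistPrev = if (cur.dist < msg.dist) cur else msg

  // sendMsg: relax edge src -> dst (weight w) and name src as dst's previous node
  def sendMsg(src: Long, srcAttr: DistPrev, dst: Long, dstAttr: DistPrev, w: Int): Iterator[(Long, DistPrev)] =
    if (srcAttr.dist + w < dstAttr.dist) Iterator(dst -> DistPrev(srcAttr.dist + w, src))
    else Iterator.empty

  // mergeMsg: of two candidates for the same vertex, keep the shorter one
  def mergeMsg(a: DistPrev, b: DistPrev): DistPrev = if (a.dist < b.dist) a else b

  def run(edges: Seq[(Long, Long, Int)], sourceId: Long): Map[Long, DistPrev] = {
    val ids = edges.flatMap { case (s, d, _) => Seq(s, d) }.distinct
    // Same state as the GraphX version after the first vprog: unreached vertices have prev = -1
    var attrs = ids.map(id =>
      id -> (if (id == sourceId) DistPrev(0.0, id) else DistPrev(Double.PositiveInfinity, -1L))).toMap

    def messages(): Map[Long, DistPrev] =
      edges.flatMap { case (s, d, w) => sendMsg(s, attrs(s), d, attrs(d), w) }
        .groupBy(_._1)
        .map { case (id, ms) => id -> ms.map(_._2).reduce((a, b) => mergeMsg(a, b)) }

    var msgs = messages()
    while (msgs.nonEmpty) {
      attrs = attrs ++ msgs.map { case (id, m) => id -> vprog(id, attrs(id), m) }
      msgs = messages()
    }
    attrs
  }
}
```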

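You can also use ans.collect.foreach(println(_)) to print the full answer, although collect brings every result to the driver. With each vertex’s previous node, it is easy to print the full path from sourceId to a specific vertex by following the previous-node links back to the source, in O(N) time.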
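Following the previous-node links back from a target vertex takes only a short loop. The sketch below assumes the (vertex, previous node) pairs have already been brought to the driver as a Map, for example with something like sssp.vertices.map { case (id, a) => (id, a(1).toLong) }.collectAsMap().toMap, which is only reasonable when the graph is small enough. PathWalk and path are hypothetical names; unreachable vertices, whose previous node is -1, yield None. Each step moves one hop closer to the source, so the cost is proportional to the path length, which is at most N.

```scala
// Rebuild a path from previous-node pointers (hypothetical helper, not part of GraphX).
object PathWalk {
  // Follows prev from `target` back to `source`; None if the chain breaks (e.g. prev = -1)
  def path(prev: Map[Long, Long], source: Long, target: Long): Option[List[Long]] = {
    @annotation.tailrec
    def loop(v: Long, acc: List[Long], steps: Int): Option[List[Long]] =
      if (v == source) Some(source :: acc)
      else if (steps > prev.size) None // guards against a cycle in malformed input
      else prev.get(v) match {
        case Some(p) if p >= 0 => loop(p, v :: acc, steps + 1)
        case _                 => None
      }
    loop(target, Nil, 0)
  }
}
```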

Another way is to store the full path instead of the previous node in each vertex, but this may consume much more memory.
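A plain-Scala sketch of this variant, with hypothetical names and no Spark, stores the whole path in the vertex state. Each message carries the sender’s path extended by the receiver, so the memory per vertex grows with the length of its shortest path, which is the cost mentioned above.

```scala
// Sketch of keeping the whole path in each vertex attribute (no Spark; hypothetical names).
object FullPathSketch {
  // Distance from the source, and the path from the source to this vertex
  final case class State(dist: Double, path: List[Long])

  def run(edges: Seq[(Long, Long, Double)], sourceId: Long): Map[Long, State] = {
    val ids = edges.flatMap { case (s, d, _) => Seq(s, d) }.distinct
    var attrs: Map[Long, State] = ids.map { id =>
      id -> (if (id == sourceId) State(0.0, List(id)) else State(Double.PositiveInfinity, Nil))
    }.toMap
    var changed = true
    while (changed) {
      // sendMsg: relax each edge and extend the sender's path by the destination;
      // mergeMsg: keep the shortest candidate per destination
      val msgs = edges.collect {
        case (s, d, w) if attrs(s).dist + w < attrs(d).dist =>
          d -> State(attrs(s).dist + w, attrs(s).path :+ d)
      }.groupBy(_._1).map { case (id, ms) => id -> ms.map(_._2).minBy(_.dist) }
      // Every message improves its target, so vprog reduces to taking the message
      attrs = attrs ++ msgs
      changed = msgs.nonEmpty
    }
    attrs
  }
}
```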

A third way (suggested by Zhouyihai-Ding in the iWCT Spark group) is to run an additional series of map steps. We now have links such as (a, b), (a, c) and (b, c); during each MapReduce iteration, we connect any two paths that can be joined, producing for example (a, b, c) and (a, c).
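One possible reading of this idea, sketched here under my own assumptions rather than as the method described in the group, is pointer jumping. Start from the one-hop links (previous node, vertex). In each round, every path whose first vertex is not the source is prefixed with the path that ends at that first vertex. The number of hops each path covers doubles per round, so about log2 of the longest path length rounds suffice, and each round is a single join that maps naturally onto MapReduce or an RDD join. PathJoin is a hypothetical name, and prev is assumed to contain every vertex reached from the source except the source itself.

```scala
// One reading of the path-joining idea as pointer jumping; PathJoin is a hypothetical name.
// Assumption: prev maps every reached vertex except the source to its previous node (a tree).
object PathJoin {
  def fullPaths(prev: Map[Long, Long], source: Long): (Map[Long, List[Long]], Int) = {
    // Start from the one-hop links (prev(v), v)
    var paths: Map[Long, List[Long]] = prev.map { case (v, p) => v -> List(p, v) }
    var rounds = 0
    var changed = true
    while (changed) {
      // Join step: a path starting at h != source is prefixed with the path ending at h
      val next = paths.map { case (v, path) =>
        val h = path.head
        v -> (if (h != source && paths.contains(h)) paths(h) ++ path.tail else path)
      }
      changed = next != paths
      if (changed) rounds += 1
      paths = next
    }
    (paths, rounds)
  }
}
```

On a chain of eight links from the source, three joining rounds are enough to produce the full path to the last vertex.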

If you have any other way to obtain the full path in a distributed style, please leave me a comment.


[1] Pregel: A System for Large-Scale Graph Processing
[2] Shortest path problem
[3] Dijkstra’s algorithm
[4] Bellman–Ford algorithm
[5] Shortest Paths Graph Algorithms
[6] Maps in Scala

16 responses on “Pregel and Shortest Path Algorithm in GraphX”

  1. Stephen
    Thanks for posting this. Can you elaborate on how I can use Spark and Graphx to find the shortest path between two vertices based on edge weight?

    This looks like a good starting point, but I don’t see anything showing me how to output the path and this implementation doesn’t use the weight on the edges at all.

    1. yuhc
      Thank you for the question. The program structure is the same as the code in the post, but each update uses the length of the edge instead of 1 (map the length to triplet.attr in the first version, or modify incrementMap in the second). I’ll try to implement it in a few days and reply to you then.
      I’ll also check the reply-email-notification plugin (it may not work properly), and reply to you again when I finish the new code.
      (Also sorry for my poor English) :smile:
      1. Stephen
        That makes sense. Can you also explain how I can print out the path? I’m looking forward to seeing your code.
        1. yuhc
          I have updated the post to include the new code (before the references). I implemented it in a straightforward, if simple, way: during each iteration, the previous node on the shortest path is updated whenever the distance is updated. If you have another method, please let me know :)
        2. yuhc
          In addition, I didn’t write the part that prints the full path. A while loop can print it using the previous-node information. I think it’s not hard, so I won’t add it to the post.
          1. Stephen
            Thanks for updating your post to include the shortest path! That is very helpful. Instead of using an Array where the first element is the cumulative weight and the second element is the id of the previous node, I think I will use a simple case class just to make things more clear. Thank you!
          2. gooeyforms
            How can you do this in O(N) time? I thought of building a tree whose root is the source, but I don’t think a while loop can generate the tree. So how do you manage it? Thx!
  2. charan
    Thanks for posting. Is it possible to extend this to find the k shortest paths between two vertices based on edge weight using GraphX and Pregel?
    1. yuhc
      Yes, I think so. You can maintain three optimal values at each vertex.
  3. Pranesh
    Hello Yuhc,

    Your post is really helpful. I have the following use case. I have map data for a given city, with the nodes (node_id, lat, lon) and an edge list (from_node, to_node, distance). I am trying to find the shortest distance between two nodes (Point A and Point B). My question is: if I give the edge list with weights (as distance) as input to GraphX, is that good enough, or, as in the GraphX documentation, should I use the function Graph(Vertex_rdd, Edge_rdd)?

    Looking forward to hearing back.

    1. yuhc
      Hi, thank you for reading my post. Previously, when I wanted to import a weighted graph, I modified GraphLoader to read the third column. If you don’t want to modify the source code, I think you need to load the edge list into an RDD of Edge objects first, and then build the graph with Graph(vertexRDD, edgeRDD). Because I haven’t focused on this area for some time, Spark may have introduced a more convenient way to do this. Good luck!
  4. Pranesh

    Can you please help me print the path (list of nodes) between two vertices?


  5. Very good implementation! However, I’m not sure about its correctness/efficiency. One important aspect of Dijkstra’s algorithm is that the next node to be updated is always the unvisited one with the smallest distance: this eliminates the possibility of backtracking (where the distances to all the subnodes of a node have to be recalculated). I notice that such adaptive iteration (e.g. in mapPartitions) doesn’t exist in your code. Is it not necessary? Can this part be optimized further?
    1. gooeyforms
      I guess there’s no “backtracking” OR “next node being updated”, because in every iteration every vertex in the graph has been included in the computation – unless one satisfies the stop condition.
  6. Akshat Kumar
    I have written the code to print the full path, but one problem I am having is converting the vertexId back to a vertexLabel or vertexName, so that the whole path is represented by vertex labels or names rather than vertex IDs.
    Can you tell me how to do that?
  7. Aroan
    I tried to change this code to produce the full path instead of the previous node, but I could not get the full path. Can anybody help me display the full path from source to destination, or tell me where I have to change this code to get it?
