Node2vec.scala 8.87 KB
Newer Older
高雅喆's avatar
高雅喆 committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
package com.gmei


import java.io.Serializable

import scala.util.Try
import scala.collection.mutable.ArrayBuffer
import org.slf4j.{Logger, LoggerFactory}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.graphx.{EdgeTriplet, Graph, _}
import com.gmei.graph.{EdgeAttr, GraphOps, NodeAttr}
import org.apache.spark.sql.DataFrame

object Node2vec extends Serializable {
  lazy val logger: Logger = LoggerFactory.getLogger(getClass.getName)

  /** Spark context supplied through [[setup]]. */
  var context: SparkContext = null
  /** Run parameters supplied through [[setup]]. */
  var config: Main.Params = null
  /** Raw node id -> dense Long index; built by [[indexingGraph]] or read by [[loadNode2Id]]. */
  var node2id: RDD[(String, Long)] = null
  /** One GraphX edge per (node, kept neighbour) pair; built by [[load]]. */
  var indexedEdges: RDD[Edge[EdgeAttr]] = _
  /** Node index -> attribute holding its (degree-capped) weighted neighbour list; built by [[load]]. */
  var indexedNodes: RDD[(VertexId, NodeAttr)] = _
  /** Graph carrying per-vertex first steps and per-edge transition tables; built by [[initTransitionProb]]. */
  var graph: Graph[NodeAttr, EdgeAttr] = _
  /** Union of the walks produced by every round of [[randomWalk]], keyed by start vertex. */
  var randomWalkPaths: RDD[(Long, ArrayBuffer[Long])] = null

  /**
   * Stores the Spark context and run parameters used by every later stage.
   *
   * @param context Spark context used to broadcast parameters and read files
   * @param param   run parameters (degree cap, p, q, walk counts, ...)
   * @return this object, for call chaining
   */
  def setup(context: SparkContext, param: Main.Params): this.type = {
    this.context = context
    this.config = param
    this
  }

  /**
   * Builds the adjacency lists (`indexedNodes`) and the edge list (`indexedEdges`) from the input.
   *
   * Each node keeps at most `config.degree` distinct neighbours; when it has more, the neighbours
   * with the largest weight are kept. Whether an input pair yields one or two directed edges is
   * decided by `GraphOps.createDirectedEdge` / `GraphOps.createUndirectedEdge`.
   *
   * @param tidb_input DataFrame with `service_id` and `cid` columns (see [[indexingGraph]])
   * @return this object, for call chaining
   * @throws UnsupportedOperationException if `config.indexed` is true, because reading a
   *                                       pre-indexed graph is currently disabled
   */
  def load(tidb_input: DataFrame): this.type = {
    if (config.indexed) {
      // readIndexedGraph(config.input) is commented out; fail with a clear message instead of
      // the bare MatchError an incomplete `match` would raise.
      throw new UnsupportedOperationException(
        "Reading a pre-indexed graph is not supported; set `indexed` to false.")
    }
    val inputTriplets: RDD[(Long, Long, Double)] = indexingGraph(tidb_input)

    val bcMaxDegree = context.broadcast(config.degree)
    val bcEdgeCreator =
      if (config.directed) context.broadcast(GraphOps.createDirectedEdge)
      else context.broadcast(GraphOps.createUndirectedEdge)

    indexedNodes = inputTriplets.flatMap { case (srcId, dstId, weight) =>
      bcEdgeCreator.value.apply(srcId, dstId, weight)
    }.reduceByKey(_ ++ _).map { case (nodeId, neighbors: Array[(VertexId, Double)]) =>
      // De-duplicate before capping so repeated input edges do not use up degree slots.
      val unique = neighbors.distinct
      val capped =
        if (unique.length > bcMaxDegree.value) unique.sortWith(_._2 > _._2).take(bcMaxDegree.value)
        else unique

      (nodeId, NodeAttr(neighbors = capped))
    }.repartition(200).cache

    indexedEdges = indexedNodes.flatMap { case (srcId, clickNode) =>
      clickNode.neighbors.map { case (dstId, _) =>
        Edge(srcId, dstId, EdgeAttr())
      }
    }.repartition(200).cache

    this
  }

  /**
   * Builds `graph` from `indexedNodes`/`indexedEdges` and precomputes sampling state.
   *
   * Every vertex gets a two-node starting path `[vertex, firstStep]`, where the first step is
   * drawn from its neighbour list through `GraphOps.setupAlias`/`GraphOps.drawAlias`. Every
   * edge stores the tables returned by `GraphOps.setupEdgeAlias(p, q)` plus the ids of the
   * destination's neighbours, which [[randomWalk]] uses to pick the next node.
   *
   * @return this object, for call chaining
   */
  def initTransitionProb(): this.type = {
    val bcP = context.broadcast(config.p)
    val bcQ = context.broadcast(config.q)

    graph = Graph(indexedNodes, indexedEdges)
      .mapVertices[NodeAttr] { case (vertexId, clickNode) =>
        val (j, q) = GraphOps.setupAlias(clickNode.neighbors)
        val nextNodeIndex = GraphOps.drawAlias(j, q)
        clickNode.path = Array(vertexId, clickNode.neighbors(nextNodeIndex)._1)

        clickNode
      }
      .mapTriplets { edgeTriplet: EdgeTriplet[NodeAttr, EdgeAttr] =>
        val (j, q) = GraphOps.setupEdgeAlias(bcP.value, bcQ.value)(
          edgeTriplet.srcId, edgeTriplet.srcAttr.neighbors, edgeTriplet.dstAttr.neighbors)
        edgeTriplet.attr.J = j
        edgeTriplet.attr.q = q
        edgeTriplet.attr.dstNeighbors = edgeTriplet.dstAttr.neighbors.map(_._1)

        edgeTriplet.attr
      }.cache

    this
  }

  /**
   * Runs `config.numWalks` rounds of walks and unions them into `randomWalkPaths`.
   *
   * Each round starts every walk from the vertex's two-node `path` and extends it
   * `config.walkLength` times. Each extension joins the walk's last edge (previous node, current
   * node) against the per-edge tables and draws the next node. The join is an inner join, so a
   * walk whose last edge has no entry is dropped from that round.
   *
   * @return this object, for call chaining
   */
  def randomWalk(): this.type = {
    // Edge (src, dst) -> transition tables. Keyed by a tuple because concatenating the two ids
    // into one string is ambiguous (edges 1->23 and 12->3 would both become "123").
    val edge2attr = graph.triplets.map { edgeTriplet =>
      ((edgeTriplet.srcId, edgeTriplet.dstId), edgeTriplet.attr)
    }.repartition(200).cache
    edge2attr.first()

    for (_ <- 0 until config.numWalks) {
      var randomWalk = graph.vertices.map { case (nodeId, clickNode) =>
        val pathBuffer = new ArrayBuffer[Long]()
        pathBuffer.append(clickNode.path: _*)
        (nodeId, pathBuffer)
      }.cache
      randomWalk.first()
      // NOTE(review): after this, later rounds recompute graph.vertices from lineage, which
      // presumably re-draws each vertex's first step; confirm this is intended.
      graph.unpersist(blocking = false)
      graph.edges.unpersist(blocking = false)

      for (_ <- 0 until config.walkLength) {
        val prevWalk = randomWalk
        randomWalk = prevWalk.map { case (srcNodeId, pathBuffer) =>
          val prevNodeId = pathBuffer(pathBuffer.length - 2)
          val currentNodeId = pathBuffer.last

          ((prevNodeId, currentNodeId), (srcNodeId, pathBuffer))
        }.join(edge2attr).map { case (edge, ((srcNodeId, pathBuffer), attr)) =>
          try {
            val nextNodeIndex = GraphOps.drawAlias(attr.J, attr.q)
            pathBuffer.append(attr.dstNeighbors(nextNodeIndex))

            (srcNodeId, pathBuffer)
          } catch {
            case scala.util.control.NonFatal(e) =>
              throw new RuntimeException(
                s"Failed to extend walk starting at $srcNodeId across edge $edge: ${e.getMessage}", e)
          }
        }.cache

        randomWalk.first()
        prevWalk.unpersist(blocking = false)
      }

      if (randomWalkPaths == null) {
        randomWalkPaths = randomWalk
      } else {
        val prevRandomWalkPaths = randomWalkPaths
        randomWalkPaths = prevRandomWalkPaths.union(randomWalk).cache()
        randomWalkPaths.first()
        prevRandomWalkPaths.unpersist(blocking = false)
      }
    }

    this
  }

  /**
   * Trains Word2vec on the generated walks, treating each walk as a sentence of node-id strings.
   *
   * @return this object, for call chaining
   * @throws IllegalArgumentException if [[randomWalk]] has not produced any walks yet
   */
  def embedding(): this.type = {
    require(randomWalkPaths != null, "randomWalk() must be called before embedding()")

    val randomPaths = randomWalkPaths
      .filter { case (_, pathBuffer) => pathBuffer != null }
      .map { case (_, pathBuffer) => pathBuffer.map(_.toString).toIterable }

    Word2vec.setup(context, config).fit(randomPaths)

    this
  }

  // Persistence helpers, currently disabled.
  //  def save(): this.type = {
  //    this.saveRandomPath()
  //        .saveModel()
  //        .saveVectors()
  //  }

  //  def saveRandomPath(): this.type = {
  //    randomWalkPaths
  //            .map { case (vertexId, pathBuffer) =>
  //              Try(pathBuffer.mkString("\t")).getOrElse(null)
  //            }
  //            .filter(x => x != null && x.replaceAll("\\s", "").length > 0)
  //            .repartition(200)
  //            .saveAsTextFile(config.output)
  //
  //    this
  //  }
  //
  //  def saveModel(): this.type = {
  //    Word2vec.save(config.output)
  //
  //    this
  //  }
  //
  //  def saveVectors(): this.type = {
  //    val node2vector = context.parallelize(Word2vec.getVectors.toList)
  //            .map { case (nodeId, vector) =>
  //              (nodeId.toLong, vector.mkString(","))
  //            }
  //
  //    if (this.node2id != null) {
  //      val id2Node = this.node2id.map{ case (strNode, index) =>
  //        (index, strNode)
  //      }
  //
  //      node2vector.join(id2Node)
  //              .map { case (nodeId, (vector, name)) => s"$name\t$vector" }
  //              .repartition(200)
  //              .saveAsTextFile(s"${config.output}.emb")
  //    } else {
  //      node2vector.map { case (nodeId, vector) => s"$nodeId\t$vector" }
  //              .repartition(200)
  //              .saveAsTextFile(s"${config.output}.emb")
  //    }
  //
  //    this
  //  }
  //

  /**
   * Releases every cached dataset this object holds. Fields that were never set, or were reset
   * to null (e.g. by a failed [[loadNode2Id]]), are skipped.
   *
   * @return this object, for call chaining
   */
  def cleanup(): this.type = {
    Option(node2id).foreach(_.unpersist(blocking = false))
    Option(indexedEdges).foreach(_.unpersist(blocking = false))
    Option(indexedNodes).foreach(_.unpersist(blocking = false))
    Option(graph).foreach(_.unpersist(blocking = false))
    Option(randomWalkPaths).foreach(_.unpersist(blocking = false))

    this
  }

  /**
   * Reads a "node<whitespace>index" file into `node2id`.
   *
   * Blank lines are skipped and surrounding/repeated whitespace is tolerated. `textFile` is lazy,
   * so a missing file or a malformed line only surfaces when `node2id` is first evaluated. The
   * `catch` here only covers failures while setting up the RDD; on such a failure `node2id` is set
   * to null and a warning is logged.
   *
   * @param node2idPath path of the mapping file; falls back to `config.nodePath` when null
   * @return this object, for call chaining
   */
  def loadNode2Id(node2idPath: String): this.type = {
    try {
      val path = Option(node2idPath).getOrElse(config.nodePath)
      this.node2id = context.textFile(path)
        .filter(_.trim.nonEmpty)
        .map { node2index =>
          val Array(strNode, index) = node2index.trim.split("\\s+")
          (strNode, index.toLong)
        }
    } catch {
      case scala.util.control.NonFatal(e) =>
        logger.warn(s"Failed to read node2index file: $node2idPath", e)
        this.node2id = null
    }

    this
  }

  //  def readIndexedGraph(tripletPath: String) = {
  //    val bcWeighted = context.broadcast(config.weighted)
  //
  //    val rawTriplets = context.textFile(tripletPath)
  //    if (config.nodePath == null) {
  //      this.node2id = createNode2Id(rawTriplets.map { triplet =>
  //        val parts = triplet.split("\\s")
  //        (parts.head, parts(1), -1)
  //      })
  //    } else {
  //      loadNode2Id(config.nodePath)
  //    }
  //
  //    rawTriplets.map { triplet =>
  //      val parts = triplet.split("\\s")
  //      val weight = bcWeighted.value match {
  //        case true => Try(parts.last.toDouble).getOrElse(1.0)
  //        case false => 1.0
  //      }
  //
  //      (parts.head.toLong, parts(1).toLong, weight)
  //    }
  //  }

  /**
   * Turns (`service_id`, `cid`) rows into index-based weighted edges and sets `node2id`.
   *
   * Every edge gets weight 1.0 because the input has no weight column. The previous code parsed
   * `cid` itself as the weight, which gave numeric ids meaningless weights. Rows where either id
   * is null are dropped. Edges whose endpoints are missing from `node2id` are dropped by the
   * inner joins.
   *
   * @param tidb_input DataFrame containing `service_id` and `cid` columns
   * @return (srcIndex, dstIndex, weight) triplets
   */
  def indexingGraph(tidb_input: DataFrame): RDD[(Long, Long, Double)] = {
    val rawEdges: RDD[(String, String, Double)] = tidb_input.rdd
      .map(row => (row.getAs[String]("service_id"), row.getAs[String]("cid")))
      .filter { case (src, dst) => src != null && dst != null }
      .map { case (src, dst) => (src, dst, 1.0) }

    this.node2id = createNode2Id(rawEdges)

    rawEdges
      .map { case (src, dst, weight) => (src, (dst, weight)) }
      .join(node2id)
      .map { case (_, ((dst, weight), srcIndex)) => (dst, (srcIndex, weight)) }
      .join(node2id)
      .map { case (_, ((srcIndex, weight), dstIndex)) => (srcIndex, dstIndex, weight) }
  }

  /**
   * Assigns a dense Long index to every distinct node appearing as a source or destination.
   *
   * @param triplets (src, dst, payload) rows; the payload is ignored
   * @return (node, index) pairs, with indices assigned by `zipWithIndex`
   */
  def createNode2Id[T](triplets: RDD[(String, String, T)]): RDD[(String, Long)] =
    triplets.flatMap { case (src, dst, _) => Seq(src, dst) }.distinct().zipWithIndex()

}