package com.gmei

import java.io.Serializable
import scala.util.Try

import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}
import org.apache.spark.sql.{SaveMode, TiContext}
import org.apache.log4j.{Level, Logger}
import scopt.OptionParser
import com.gmei.lib.AbstractParams
import com.soundcloud.lsh.Lsh



object Main {

  // Tune framework logging once, when the enclosing object is initialised:
  // the Spark logger is limited to WARN and the Jetty server logger is switched off.
  // NOTE(review): Jetty's own packages are normally rooted at "org.eclipse.jetty" (no "apache"
  // segment) — confirm that the second logger name actually matches anything.
  Seq(
    "org.apache.spark" -> Level.WARN,
    "org.apache.eclipse.jetty.server" -> Level.OFF
  ).foreach { case (loggerName, level) =>
    Logger.getLogger(loggerName).setLevel(level)
  }


  /**
    * Job configuration, populated from the command line by `parser`.
    *
    * The whole instance is handed to `GmeiConfig.setup` and `Node2vec.setup`; how the individual
    * fields are consumed there is not visible in this file.
    *
    * @param p        return parameter p
    * @param q        in-out parameter q
    * @param indexed  whether nodes are indexed or not
    * @param env      the database environment to use
    * @param nodePath input node2index file path; `null` when not supplied
    */
  case class Params(
      iter: Int = 10,
      lr: Double = 0.025,
      numPartition: Int = 10,
      dim: Int = 128,
      window: Int = 10,
      walkLength: Int = 80,
      numWalks: Int = 10,
      p: Double = 1.0,
      q: Double = 1.0,
      weighted: Boolean = true,
      directed: Boolean = false,
      degree: Int = 30,
      indexed: Boolean = false,
      env: String = ENV.DEV,
      nodePath: String = null
  ) extends AbstractParams[Params]
      with Serializable

  /** Baseline configuration; also the source of the default values shown in the help text. */
  val defaultParams: Params = Params()

  /**
    * scopt command-line parser that turns program arguments into a [[Params]] value.
    *
    * Each option copies its parsed value into the matching `Params` field.
    *
    * NOTE(review): `iter`, `lr`, `numPartition`, `dim` and `window` have no option defined here,
    * so they always keep their default values when the job is started through `main`.
    */
  val parser: OptionParser[Params] = new OptionParser[Params]("Node2Vec_Spark") {
    head("Main")
    opt[Int]("walkLength")
      .text(s"walkLength: ${defaultParams.walkLength}")
      .action((x, c) => c.copy(walkLength = x))
    opt[Int]("numWalks")
      .text(s"numWalks: ${defaultParams.numWalks}")
      .action((x, c) => c.copy(numWalks = x))
    opt[Double]("p")
      .text(s"return parameter p: ${defaultParams.p}")
      .action((x, c) => c.copy(p = x))
    opt[Double]("q")
      .text(s"in-out parameter q: ${defaultParams.q}")
      .action((x, c) => c.copy(q = x))
    opt[Boolean]("weighted")
      .text(s"weighted: ${defaultParams.weighted}")
      .action((x, c) => c.copy(weighted = x))
    opt[Boolean]("directed")
      .text(s"directed: ${defaultParams.directed}")
      .action((x, c) => c.copy(directed = x))
    opt[Int]("degree")
      .text(s"degree: ${defaultParams.degree}")
      .action((x, c) => c.copy(degree = x))
    opt[Boolean]("indexed")
      .text(s"Whether nodes are indexed or not: ${defaultParams.indexed}")
      .action((x, c) => c.copy(indexed = x))
    opt[String]("env")
      .text("the databases environment you used")
      .action((x, c) => c.copy(env = x))
    opt[String]("nodePath")
      .text("Input node2index file path: empty")
      .action((x, c) => c.copy(nodePath = x))
    // The first literal stays a raw (non-interpolated) string so the trailing backslash is kept
    // verbatim. stripMargin is applied to the concatenation so the margin bar of the appended
    // "--env" line is stripped as well, instead of being printed literally.
    note(
      ("""
        |For example, the following command runs this app on a tidb dataset:
        |
        | spark-submit --class com.gmei.Main ./target/scala-2.11/Node2vec-assembly-0.2.jar \
      """ + s"|   --env ${defaultParams.env}").stripMargin
    )
  }


  /**
    * Program entry point.
    *
    * Parses `args` with `parser`, starting from `defaultParams`. When parsing succeeds the job
    * runs; otherwise the JVM exits with status 1.
    *
    * @param args raw command-line arguments
    */
  def main(args: Array[String]): Unit = {
    parser.parse(args, defaultParams) match {
      case Some(param) => runPipeline(param)
      case None        => sys.exit(1)
    }
  }

  /**
    * Runs the whole job for one parsed configuration:
    *
    *  1. reads `service_id,cid` rows from `nd_data_meigou_cid` and trains node2vec on them;
    *  2. runs an LSH cosine-similarity join over the embeddings, writes the scored pairs to
    *     `nd_cid_pairs_cosine_distince` and the per-cid neighbour lists to
    *     `nd_cid_similarity_matrix`;
    *  3. joins the neighbour lists with `data_feed_click` and writes the per-device result to
    *     `nd_device_cid_similarity_matrix`.
    *
    * The Spark session is stopped when the job ends, whether it succeeded or threw.
    *
    * @param param parsed job configuration
    */
  private def runPipeline(param: Params): Unit = {
    //1. get the input and node2vec
    GmeiConfig.setup(param)
    val sparkEnv = GmeiConfig.getSparkSession()
    val context = sparkEnv._1 // used for parallelize below
    val sc = sparkEnv._2 // used for sql / implicits below, despite the short name

    try {
      val tidbDatabase = GmeiConfig.config.getString("tidb.database")

      val ti = new TiContext(sc)
      ti.tidbMapTable(dbName = tidbDatabase, tableName = "nd_data_meigou_cid")
      ti.tidbMapTable(dbName = tidbDatabase, tableName = "data_feed_click")

      // presumably the date 30 days back, formatted to compare with stat_date — confirm in GmeiConfig
      val date8 = GmeiConfig.getMinusNDate(30)
      val tidbInput = sc.sql(
        s"""
           |SELECT
           | service_id,cid
           |FROM nd_data_meigou_cid
           |where stat_date > '${date8}'
     """.stripMargin
      )

      Node2vec.setup(context, param)

      Node2vec.load(tidbInput)
        .initTransitionProb()
        .randomWalk()
        .embedding()

      // (numeric node id, embedding as doubles)
      val node2vector = context.parallelize(Word2vec.getVectors.toList)
        .map { case (nodeId, vector) =>
          (nodeId.toLong, vector.map(_.toDouble))
        }

      // invert node2id so the numeric id becomes the join key
      val id2Node = Node2vec.node2id.map { case (strNode, nodeIndex) =>
        (nodeIndex, strNode)
      }

      val node2vec_2 = node2vector.join(id2Node)
        .map { case (_, (vector, name)) => (name, vector) }
        .repartition(200)

      //2. compute similar cid and then take top k
      val storageLevel = StorageLevel.MEMORY_AND_DISK

      val indexed = node2vec_2.zipWithIndex.persist(storageLevel)

      // create indexed row matrix where every row represents one word
      val rows = indexed.map { case ((_, features), rowIndex) =>
        IndexedRow(rowIndex, Vectors.dense(features))
      }

      // store index for later re-mapping (index to word)
      val index = indexed.map { case ((word, _), rowIndex) =>
        (rowIndex, word)
      }.persist(storageLevel)

      // create an input matrix from all rows and run lsh on it
      val matrix = new IndexedRowMatrix(rows)
      val lsh = new Lsh(
        minCosineSimilarity = 0.5,
        dimensions = 20,
        numNeighbours = 200,
        numPermutations = 10,
        partitions = 200,
        storageLevel = storageLevel
      )
      val similarityMatrix = lsh.join(matrix)

      import sc.implicits._
      // remap both ids back to words
      val remapFirst = similarityMatrix.entries.keyBy(_.i).join(index).values

      // remapSecond feeds several actions below (take, JDBC write, grouping); persist it so the
      // LSH join and both re-mapping joins are not recomputed for each of them.
      val remapSecond = remapFirst
        .keyBy { case (entry, _) => entry.j }
        .join(index)
        .values
        .map { case ((entry, word1), word2) => (word1, word2, entry.value) }
        .persist(storageLevel)
      remapSecond.take(20).foreach(println)

      val score_result = remapSecond.toDF("cid1","cid2","score")
      GmeiConfig.writeToJDBCTable(score_result, table="nd_cid_pairs_cosine_distince", SaveMode.Overwrite)

      // group by neighbours to get a list of similar words and then take top k
      val result = remapSecond.filter(_._1.startsWith("diary")).groupBy(_._1).map {
        case (word1, similarWords) =>
          // sort by score desc. and take top 20 entries; a failure yields an empty list,
          // which the length filter below discards
          val similar = Try(
            similarWords.toSeq
              .sortBy(-1 * _._3)
              .filter(_._2.startsWith("diary"))
              .take(20)
              .map(_._2)
              .mkString(",")
          ).getOrElse("")
          (word1, similar)
      }.filter(_._2.split(",").length > 4) // keep only cids with at least five neighbours
      result.take(20).foreach(println)

      val similar_result = result.toDF("cid","similarity_cid")
      GmeiConfig.writeToJDBCTable(similar_result, table="nd_cid_similarity_matrix", SaveMode.Overwrite)

      //3. cids queue map to device_id
      ti.tidbMapTable(dbName = tidbDatabase, tableName = "nd_cid_similarity_matrix")

      val device_id = sc.sql(
        s"""
           |select a.device_id device_id,a.city_id city_id ,b.similarity_cid similarity_cid from
           |(select device_id,first(city_id) as city_id,first(cid) as cid from data_feed_click
           |where cid in (select cid from nd_cid_similarity_matrix)
           |group by device_id) a left join
           |nd_cid_similarity_matrix b
           |on a.cid = b.cid
           |where b.similarity_cid is not null
     """.stripMargin
      ).na.fill(Map("city_id"->"beijing"))
      device_id.show()

      val device_queue = device_id.rdd.map { item =>
        val parts = (
          item.getAs[String](fieldName = "device_id"),
          item.getAs[String](fieldName = "city_id"),
          item.getAs[String](fieldName = "similarity_cid")
        )

        // "worldwide" is mapped to "beijing" and the "diary|" prefix is dropped from the cid list;
        // a field that cannot be processed becomes null rather than failing the row
        Try {
          (
            parts._1,
            Try(parts._2.toString.replace("worldwide","beijing")).getOrElse(null),
            Try(parts._3.toString.replace("diary|","")).getOrElse(null)
          )
        }.getOrElse(null)
      }.filter(_ != null).toDF("device_id","city_id","similarity_cid")

      device_queue.take(20).foreach(println)
      GmeiConfig.writeToJDBCTable(device_queue, table="nd_device_cid_similarity_matrix", SaveMode.Overwrite)
    } finally {
      // release the session even when one of the stages above throws
      sc.stop()
    }
  }
}