package com.gmei

import java.io.Serializable

import scala.util.Try

import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}
import org.apache.spark.sql.{SaveMode, TiContext}
import org.apache.log4j.{Level, Logger}
import scopt.OptionParser
import com.gmei.lib.AbstractParams
import com.soundcloud.lsh.Lsh

object Main {

  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)

  case class Params(
    iter: Int = 10,
    lr: Double = 0.025,
    numPartition: Int = 10,
    dim: Int = 128,
    window: Int = 10,
    walkLength: Int = 80,
    numWalks: Int = 10,
    p: Double = 1.0,
    q: Double = 1.0,
    weighted: Boolean = true,
    directed: Boolean = false,
    degree: Int = 30,
    indexed: Boolean = false,
    env: String = ENV.DEV,
    nodePath: String = null
  ) extends AbstractParams[Params] with Serializable

  val defaultParams = Params()

  val parser = new OptionParser[Params]("Node2Vec_Spark") {
    head("Main")
    opt[Int]("walkLength")
      .text(s"walkLength: ${defaultParams.walkLength}")
      .action((x, c) => c.copy(walkLength = x))
    opt[Int]("numWalks")
      .text(s"numWalks: ${defaultParams.numWalks}")
      .action((x, c) => c.copy(numWalks = x))
    opt[Double]("p")
      .text(s"return parameter p: ${defaultParams.p}")
      .action((x, c) => c.copy(p = x))
    opt[Double]("q")
      .text(s"in-out parameter q: ${defaultParams.q}")
      .action((x, c) => c.copy(q = x))
    opt[Boolean]("weighted")
      .text(s"weighted: ${defaultParams.weighted}")
      .action((x, c) => c.copy(weighted = x))
    opt[Boolean]("directed")
      .text(s"directed: ${defaultParams.directed}")
      .action((x, c) => c.copy(directed = x))
    opt[Int]("degree")
      .text(s"degree: ${defaultParams.degree}")
      .action((x, c) => c.copy(degree = x))
    opt[Boolean]("indexed")
      .text(s"whether nodes are already indexed: ${defaultParams.indexed}")
      .action((x, c) => c.copy(indexed = x))
    opt[String]("env")
      .text("the database environment to use")
      .action((x, c) => c.copy(env = x))
    opt[String]("nodePath")
      .text("input node2index file path: empty")
      .action((x, c) => c.copy(nodePath = x))
    note(
      """
        |For example, the following command runs this app on a tidb dataset:
        |
        | spark-submit --class com.gmei.Main ./target/scala-2.11/Node2vec-assembly-0.2.jar \
      """.stripMargin + s"| --env ${defaultParams.env}"
    )
  }

  def main(args: Array[String]): Unit = {
    parser.parse(args, defaultParams).map { param =>
      // 1. Read the input edges from TiDB and train node2vec embeddings.
      GmeiConfig.setup(param)
      val (context, spark) = GmeiConfig.getSparkSession()

      val ti = new TiContext(spark)
      ti.tidbMapTable(dbName = GmeiConfig.config.getString("tidb.database"), tableName = "nd_data_meigou_cid")
      ti.tidbMapTable(dbName = GmeiConfig.config.getString("tidb.database"), tableName = "data_feed_click")

      // Only keep edges from the last 30 days.
      val startDate = GmeiConfig.getMinusNDate(30)
      val tidb_input = spark.sql(
        s"""
           |SELECT
           |  service_id, cid
           |FROM nd_data_meigou_cid
           |WHERE stat_date > '$startDate'
        """.stripMargin
      )

      Node2vec.setup(context, param)
      Node2vec.load(tidb_input)
        .initTransitionProb()
        .randomWalk()
        .embedding()

      val node2vector = context.parallelize(Word2vec.getVectors.toList)
        .map { case (nodeId, vector) =>
          (nodeId.toLong, vector.map(_.toDouble))
        }

      // Invert the node2id mapping so embeddings can be keyed by node name.
      val id2Node = Node2vec.node2id.map { case (strNode, index) =>
        (index, strNode)
      }

      val node2vec_2 = node2vector.join(id2Node)
        .map { case (nodeId, (vector, name)) => (name, vector) }
        .repartition(200)
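      // The LSH join below only approximates cosine similarity. The helper
      // below is a minimal sketch (not part of the original job; the name
      // cosineSimilarity is ours) for spot-checking a candidate pair against
      // the exact cosine value, assuming the dense double vectors built above.
      def cosineSimilarity(a: Array[Double], b: Array[Double]): Double = {
        val dot = a.zip(b).map { case (x, y) => x * y }.sum
        val norms = math.sqrt(a.map(x => x * x).sum) * math.sqrt(b.map(x => x * x).sum)
        if (norms == 0.0) 0.0 else dot / norms
      }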
      // 2. Compute similar cids with LSH, then take the top k for each.
      val storageLevel = StorageLevel.MEMORY_AND_DISK
      val indexed = node2vec_2.zipWithIndex.persist(storageLevel)

      // Create an indexed row matrix where every row represents one word.
      val rows = indexed.map { case ((word, features), index) =>
        IndexedRow(index, Vectors.dense(features))
      }

      // Store the index for later re-mapping (index to word).
      val index = indexed.map { case ((word, features), index) =>
        (index, word)
      }.persist(storageLevel)

      // Create an input matrix from all rows and run LSH on it.
      val matrix = new IndexedRowMatrix(rows)
      val lsh = new Lsh(
        minCosineSimilarity = 0.5,
        dimensions = 20,
        numNeighbours = 200,
        numPermutations = 10,
        partitions = 200,
        storageLevel = storageLevel
      )
      val similarityMatrix = lsh.join(matrix)

      import spark.implicits._

      // Remap both ids of each entry back to words.
      val remapFirst = similarityMatrix.entries.keyBy(_.i).join(index).values
      val remapSecond = remapFirst.keyBy { case (entry, word1) => entry.j }
        .join(index).values
        .map { case ((entry, word1), word2) => (word1, word2, entry.value) }
      remapSecond.take(20).foreach(println)

      val score_result = remapSecond.toDF("cid1", "cid2", "score")
      GmeiConfig.writeToJDBCTable(score_result, table = "nd_cid_pairs_cosine_distince", SaveMode.Overwrite)

      // Group by neighbours to get a list of similar words, then take the top k.
      val result = remapSecond.filter(_._1.startsWith("diary")).groupBy(_._1).map {
        case (word1, similarWords) =>
          // Sort by score descending and take the top 20 entries.
          val similar = Try(
            similarWords.toSeq
              .sortBy(-1 * _._3)
              .filter(_._2.startsWith("diary"))
              .take(20)
              .map(_._2)
              .mkString(",")
          ).getOrElse(null)
          (word1, s"$similar")
      }.filter(_._2.split(",").length > 4)
      result.take(20).foreach(println)

      val similar_result = result.toDF("cid", "similarity_cid")
      GmeiConfig.writeToJDBCTable(similar_result, table = "nd_cid_similarity_matrix", SaveMode.Overwrite)

      // 3. Map each cid similarity queue to a device_id.
      ti.tidbMapTable(dbName = GmeiConfig.config.getString("tidb.database"), tableName = "nd_cid_similarity_matrix")

      val device_id = spark.sql(
        s"""
           |SELECT a.device_id device_id, a.city_id city_id, b.similarity_cid similarity_cid
           |FROM
           |  (SELECT device_id, first(city_id) AS city_id, first(cid) AS cid
           |   FROM data_feed_click
           |   WHERE cid IN (SELECT cid FROM nd_cid_similarity_matrix)
           |   GROUP BY device_id) a
           |LEFT JOIN nd_cid_similarity_matrix b
           |ON a.cid = b.cid
           |WHERE b.similarity_cid IS NOT NULL
        """.stripMargin
      ).na.fill(Map("city_id" -> "beijing"))
      device_id.show()

      val device_queue = device_id.rdd.map { item =>
        val parts = (
          item.getAs[String](fieldName = "device_id"),
          item.getAs[String](fieldName = "city_id"),
          item.getAs[String](fieldName = "similarity_cid")
        )
        Try {
          (
            parts._1,
            Try(parts._2.replace("worldwide", "beijing")).getOrElse(null),
            Try(parts._3.replace("diary|", "")).getOrElse(null)
          )
        }.getOrElse(null)
      }.filter(_ != null).toDF("device_id", "city_id", "similarity_cid")
      device_queue.take(20).foreach(println)

      GmeiConfig.writeToJDBCTable(device_queue, table = "nd_device_cid_similarity_matrix", SaveMode.Overwrite)

      spark.stop()
    } getOrElse {
      sys.exit(1)
    }
  }
}
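// Illustrative invocation (a sketch: the jar path comes from the note in the
// option parser above, the flag values are just the defaults from Params, and
// the accepted --env values depend on ENV/GmeiConfig):
//
//   spark-submit --class com.gmei.Main ./target/scala-2.11/Node2vec-assembly-0.2.jar \
//     --walkLength 80 --numWalks 10 --p 1.0 --q 1.0 --env dev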