package com.gmei

import java.io.Serializable
import scala.util.Try

import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}
import org.apache.spark.sql.{SaveMode, TiContext}
import org.apache.log4j.{Level, Logger}
import scopt.OptionParser
import com.gmei.lib.AbstractParams
import com.soundcloud.lsh.Lsh



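// Trains node2vec embeddings on TiDB click data, computes pairwise cosine
// similarity between cids via LSH, and writes the resulting similarity
// queues back to TiDB.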
object Main {

  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)


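  // Hyperparameters for the node2vec walk and embedding. p is the return
  // parameter and q the in-out parameter of the biased random walk;
  // walkLength and numWalks control how many samples are drawn per node.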
  case class Params(iter: Int = 10,
                    lr: Double = 0.025,
                    numPartition: Int = 10,
                    dim: Int = 128,
                    window: Int = 10,
                    walkLength: Int = 80,
                    numWalks: Int = 10,
                    p: Double = 1.0,
                    q: Double = 1.0,
                    weighted: Boolean = true,
                    directed: Boolean = false,
                    degree: Int = 30,
                    indexed: Boolean = false,
                    env: String = ENV.DEV,
                    nodePath: String = null
                   ) extends AbstractParams[Params] with Serializable
  val defaultParams = Params()

  val parser = new OptionParser[Params]("Node2Vec_Spark") {
    head("Main")
    opt[Int]("walkLength")
      .text(s"walkLength: ${defaultParams.walkLength}")
      .action((x, c) => c.copy(walkLength = x))
    opt[Int]("numWalks")
      .text(s"numWalks: ${defaultParams.numWalks}")
      .action((x, c) => c.copy(numWalks = x))
    opt[Double]("p")
      .text(s"return parameter p: ${defaultParams.p}")
      .action((x, c) => c.copy(p = x))
    opt[Double]("q")
      .text(s"in-out parameter q: ${defaultParams.q}")
      .action((x, c) => c.copy(q = x))
    opt[Boolean]("weighted")
      .text(s"weighted: ${defaultParams.weighted}")
      .action((x, c) => c.copy(weighted = x))
    opt[Boolean]("directed")
      .text(s"directed: ${defaultParams.directed}")
      .action((x, c) => c.copy(directed = x))
    opt[Int]("degree")
      .text(s"degree: ${defaultParams.degree}")
      .action((x, c) => c.copy(degree = x))
    opt[Boolean]("indexed")
      .text(s"Whether nodes are indexed or not: ${defaultParams.indexed}")
      .action((x, c) => c.copy(indexed = x))
    opt[String]("env")
      .text(s"the databases environment you used")
      .action((x,c) => c.copy(env = x))
    opt[String]("nodePath")
      .text("Input node2index file path: empty")
      .action((x, c) => c.copy(nodePath = x))
    note(
      s"""
         |For example, the following command runs this app on a tidb dataset:
         |
         | spark-submit --class com.gmei.Main ./target/scala-2.11/Node2vec-assembly-0.2.jar \\
         |   --env ${defaultParams.env}
       """.stripMargin
    )
  }


  def main(args: Array[String]): Unit = {

    parser.parse(args, defaultParams).map { param =>

      // 1. Load the input edges from TiDB and train node2vec embeddings.
      GmeiConfig.setup(param)
      val (context, sc) = GmeiConfig.getSparkSession()

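      // Map the TiDB source tables into the Spark catalog via TiSpark.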
      val ti = new TiContext(sc)
      val database = GmeiConfig.config.getString("tidb.database")
      ti.tidbMapTable(dbName = database, tableName = "nd_data_meigou_cid")
      ti.tidbMapTable(dbName = database, tableName = "data_feed_click")


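      // Pull (service_id, cid) edges from the last 30 days as the input graph.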
      val startDate = GmeiConfig.getMinusNDate(30)
      val tidb_input = sc.sql(
        s"""
           |SELECT service_id, cid
           |FROM nd_data_meigou_cid
           |WHERE stat_date > '$startDate'
         """.stripMargin
      )

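      // Build transition probabilities, sample random walks, and train
      // skip-gram embeddings on the walks.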
      Node2vec.setup(context, param)

      Node2vec.load(tidb_input)
        .initTransitionProb()
        .randomWalk()
        .embedding()

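      // Collect the learned vectors, keyed by internal node id.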
      val node2vector = context.parallelize(Word2vec.getVectors.toList)
        .map { case (nodeId, vector) =>
          (nodeId.toLong, vector.map(_.toDouble))
        }

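      // Invert node2id so internal ids can be mapped back to node names.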
      val id2Node = Node2vec.node2id.map { case (strNode, index) =>
        (index, strNode)
      }

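      // Re-key each embedding vector by its original node name.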
      val node2vec_2 = node2vector.join(id2Node)
        .map { case (nodeId, (vector, name)) => (name, vector) }
        .repartition(200)


      // 2. Compute similar cids with cosine LSH and keep the top k per cid.

      val storageLevel = StorageLevel.MEMORY_AND_DISK

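      // Assign every node a row index for the LSH input matrix.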
      val indexed = node2vec_2.zipWithIndex.persist(storageLevel)

      // create indexed row matrix where every row represents one word
      val rows = indexed.map {
        case ((word, features), index) =>
          IndexedRow(index, Vectors.dense(features))
      }

      // store index for later re-mapping (index to word)
      val index = indexed.map {
        case ((word, features), index) =>
          (index, word)
      }.persist(storageLevel)

      // create an input matrix from all rows and run lsh on it
      val matrix = new IndexedRowMatrix(rows)
      val lsh = new Lsh(
        minCosineSimilarity = 0.5,
        dimensions = 20,
        numNeighbours = 200,
        numPermutations = 10,
        partitions = 200,
        storageLevel = storageLevel
      )
      val similarityMatrix = lsh.join(matrix)

      import sc.implicits._
      // remap both ids back to words
      val remapFirst = similarityMatrix.entries.keyBy(_.i).join(index).values


      val remapSecond = remapFirst.keyBy { case (entry, word1) => entry.j }.join(index).values.map {
        case ((entry, word1), word2) =>
          (word1, word2, entry.value)
      }
      remapSecond.take(20).foreach(println)

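      // Persist all pairwise similarity scores to TiDB.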
      val score_result = remapSecond.toDF("cid1","cid2","score")
      GmeiConfig.writeToJDBCTable(score_result, table="nd_cid_pairs_cosine_distince", SaveMode.Overwrite)


      // Group by cid to get a list of similar cids, then take the top k.
      val result = remapSecond.filter(_._1.startsWith("diary")).groupBy(_._1).map {
        case (word1, similarWords) =>
          // Sort by score descending, keep diary cids only, and take the top 20.
          val similar = similarWords.toSeq
            .sortBy(-_._3)
            .filter(_._2.startsWith("diary"))
            .take(20)
            .map(_._2)
            .mkString(",")
          (word1, similar)
      }.filter(_._2.split(",").length > 4)
      result.take(20).foreach(println)

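      // Persist the top-k similarity queue of each cid to TiDB.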
      val similar_result = result.toDF("cid","similarity_cid")
      GmeiConfig.writeToJDBCTable(similar_result, table="nd_cid_similarity_matrix", SaveMode.Overwrite)


      // 3. Map each device_id to the similarity queue of its clicked cid.
      ti.tidbMapTable(dbName = database, tableName = "nd_cid_similarity_matrix")

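      // For each device, take the first clicked cid and attach that cid's
      // similarity queue; missing city_id values default to "beijing".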
      val device_id = sc.sql(
        """
          |SELECT a.device_id device_id, a.city_id city_id, b.similarity_cid similarity_cid
          |FROM
          |  (SELECT device_id, first(city_id) AS city_id, first(cid) AS cid
          |   FROM data_feed_click
          |   WHERE cid IN (SELECT cid FROM nd_cid_similarity_matrix)
          |   GROUP BY device_id) a
          |LEFT JOIN nd_cid_similarity_matrix b
          |ON a.cid = b.cid
          |WHERE b.similarity_cid IS NOT NULL
        """.stripMargin
      ).na.fill(Map("city_id" -> "beijing"))
      device_id.show()

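      // Clean each row: map the "worldwide" city to "beijing" and strip the
      // "diary|" prefix from the similarity queue.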
      val device_queue = device_id.rdd.map { item =>
        val deviceId = item.getAs[String]("device_id")
        val cityId = item.getAs[String]("city_id")
        val similarityCid = item.getAs[String]("similarity_cid")

        Try {
          (deviceId,
            cityId.replace("worldwide", "beijing"),
            similarityCid.replace("diary|", ""))
        }.getOrElse(null)
      }.filter(_ != null).toDF("device_id", "city_id", "similarity_cid")

      device_queue.take(20).foreach(println)
      GmeiConfig.writeToJDBCTable(device_queue, table="nd_device_cid_similarity_matrix", SaveMode.Overwrite)


      sc.stop()

    } getOrElse {
      sys.exit(1)
    }
  }
}