Commit 2a1ff68f authored by 张彦钊

Merge branch 'master' of git.wanmeizhensuo.com:ML/ffm-baseline

Add database write script
parents ee38e8f1 3203f457
@@ -8,7 +8,7 @@ def con_sql(sql):
     :type sql : str
     :rtype : tuple
     """
-    db = pymysql.connect(host='10.66.157.22', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    db = pymysql.connect(host='10.66.157.22', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_prod')
     cursor = db.cursor()
     cursor.execute(sql)
     result = cursor.fetchall()
*.class
*.log
build.sbt_back
# sbt specific
dist/*
target/
lib_managed/
src_managed/
project/boot/
project/plugins/project/
sbt/*.jar
mini-complete-example/sbt/*.jar
spark-warehouse/
# Scala-IDE specific
.scala_dependencies
#Emacs
*~
#ignore the metastore
metastore_db/*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
.env
.Python
env/bin/
build/*.jar
develop-eggs/
dist/
eggs/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml
# Translations
*.mo
# Mr Developer
.mr.developer.cfg
.project
.pydevproject
# Rope
.ropeproject
# Django stuff:
*.log
*.pot
# Sphinx documentation
docs/_build/
# PyCharm files
*.idea
# emacs stuff
# Autoenv
.env
*~
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
.env
.Python
env/
bin/
build/
develop-eggs/
dist/
eggs/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml
# Translations
*.mo
# Mr Developer
.mr.developer.cfg
.project
.pydevproject
# Rope
.ropeproject
# Django stuff:
*.log
*.pot
# Sphinx documentation
docs/_build/
# PyCharm files
*.idea
# emacs stuff
\#*\#
\.\#*
# Autoenv
.env
*~
name := """Node2vec"""
lazy val commonSettings = Seq(
version := "0.2",
organization := "com.gmei",
scalaVersion := "2.11.8",
test in assembly := {}
)
autoScalaLibrary := false
val sparkVersion = "2.2.1"
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-core" % sparkVersion,
"org.apache.spark" %% "spark-sql" % sparkVersion,
"org.apache.spark" %% "spark-hive" % sparkVersion,
"org.apache.spark" %% "spark-streaming" % sparkVersion,
"org.apache.spark" %% "spark-streaming-kafka-0-10" % sparkVersion,
"org.apache.spark" %% "spark-mllib" % sparkVersion,
"mysql" % "mysql-connector-java" % "5.1.38",
"com.typesafe" % "config" % "1.3.2",
"org.apache.logging.log4j" % "log4j-scala" % "11.0" pomOnly(),
"org.scalatest" %% "scalatest" % "3.0.5" % "test",
"com.github.nscala-time" %% "nscala-time" % "2.18.0",
"com.github.scopt" %% "scopt" % "3.7.0",
"com.google.guava" % "guava" % "19.0"
)
lazy val root = (project in file(".")).settings(commonSettings: _*)
assemblyMergeStrategy in assembly := {
case PathList("META-INF", xs @ _*) => MergeStrategy.discard
case x => MergeStrategy.first
}
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.6")
\ No newline at end of file
dev.tidb.jdbcuri=jdbc:mysql://192.168.15.12:4000/jerry_test?user=root&password=&rewriteBatchedStatements=true
dev.tispark.pd.addresses=192.168.15.11:2379
dev.mimas.jdbcuri= jdbc:mysql://rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com/mimas_test?user=work&password=workwork&rewriteBatchedStatements=true
dev.gaia.jdbcuri=jdbc:mysql://rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com/zhengxing_test?user=work&password=workwork&rewriteBatchedStatements=true
dev.gold.jdbcuri=jdbc:mysql://rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com/doris_test?user=work&password=workwork&rewriteBatchedStatements=true
dev.tidb.database=jerry_test
pre.tidb.jdbcuri=jdbc:mysql://192.168.16.11:4000/eagle?user=root&password=&rewriteBatchedStatements=true
pre.tispark.pd.addresses=192.168.16.11:2379
pre.mimas.jdbcuri=jdbc:mysql://rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com:3308/mimas_prod?user=mimas&password=workwork&rewriteBatchedStatements=true
prod.tidb.database=jerry_prod
prod.tidb.jdbcuri=jdbc:mysql://10.66.157.22:4000/jerry_prod?user=root&password=3SYz54LS9#^9sBvC&rewriteBatchedStatements=true
prod.gold.jdbcuri=jdbc:mysql://rm-m5e842126ng59jrv6.mysql.rds.aliyuncs.com/doris_prod?user=doris&password=o5gbA27hXHHm&rewriteBatchedStatements=true
prod.mimas.jdbcuri=jdbc:mysql://rm-m5emg41za2w7l6au3.mysql.rds.aliyuncs.com/mimas_prod?user=mimas&password=GJL3UJe1Ck9ggL6aKnZCq4cRvM&rewriteBatchedStatements=true
prod.gaia.jdbcuri=jdbc:mysql://rdsfewzdmf0jfjp9un8xj.mysql.rds.aliyuncs.com/zhengxing?user=work&password=BJQaT9VzDcuPBqkd&rewriteBatchedStatements=true
prod.tispark.pd.addresses=10.66.157.22:2379
prod.redis.host=10.30.50.58
prod.redis.port=6379
\ No newline at end of file
appender.out.type = Console
appender.out.name = out
appender.out.layout.type = PatternLayout
appender.out.layout.pattern = [%30.30t] %-30.30c{1} %-5p %m%n
logger.springframework.name = org.springframework
logger.springframework.level = WARN
rootLogger.level = INFO
rootLogger.appenderRef.out.ref = out
package com.gmei
object ENV extends Enumeration {
type ENV = String
val PROD = "prod"
val DEV = "dev"
val PRE = "pre"
}
package com.gmei
import java.util.Properties
import java.io.Serializable
import com.typesafe.config._
import org.apache.spark.{SparkConf,SparkContext}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession, TiContext}
import com.gmei.ENV.ENV
object GmeiConfig extends Serializable {
var param1: Main.Params = null
var env: String = null
var config: Config = null
def setup(param: Main.Params): this.type = {
this.param1 = param
this.env = this.param1.env match {
case "prod" => ENV.PROD
case "dev" => ENV.DEV
case "pre" => ENV.PRE
case _ => ENV.DEV
}
this.config = initConfig(this.env)
this
}
// In this case param1 would still be null, so a NullPointerException would be thrown:
// val env = param1.env match {
// case "prod" => ENV.PROD
// case "dev" => ENV.DEV
// case "pre" => ENV.PRE
// case _ => ENV.DEV
// }
// val config = initConfig(env)
def initConfig(env: ENV) = {
lazy val c = ConfigFactory.load()
c.getConfig(env).withFallback(c)
}
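// Builds the shared SparkContext/SparkSession: enables cross joins, defaults to local[3] when no master is set, and takes the TiSpark PD addresses from the loaded config.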
def getSparkSession():(SparkContext, SparkSession) = {
val sparkConf = new SparkConf
sparkConf.set("spark.sql.crossJoin.enabled", "true")
if (!sparkConf.contains("spark.master")) {
sparkConf.setMaster("local[3]")
}
if (!sparkConf.contains("spark.tispark.pd.addresses")) {
sparkConf.set("spark.tispark.pd.addresses", this.config.getString("tispark.pd.addresses"))
}
println(sparkConf.get("spark.tispark.pd.addresses"))
val spark = SparkSession
.builder()
.config(sparkConf)
.appName("node2vec")
.getOrCreate()
val context = SparkContext.getOrCreate(sparkConf)
(context, spark)
}
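// Writes a DataFrame to MySQL/TiDB over JDBC: 128 partitions, 300 rows per batch insert; truncate=true keeps the table definition when SaveMode.Overwrite is used.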
def writeToJDBCTable(jdbcuri: String, df: DataFrame, table: String, saveModel: SaveMode): Unit = {
println(jdbcuri, table)
val prop = new Properties()
prop.put("driver", "com.mysql.jdbc.Driver")
prop.put("useSSL", "false")
prop.put("isolationLevel", "NONE")
prop.put("truncate", "true")
// save to mysql/tidb
df.repartition(128).write.mode(saveModel)
.option(JDBCOptions.JDBC_BATCH_INSERT_SIZE, 300)
.jdbc(jdbcuri, table, prop)
}
def writeToJDBCTable(df: DataFrame, table: String, saveMode: SaveMode): Unit = {
val jdbcuri = this.config.getString("tidb.jdbcuri")
println(jdbcuri, table)
writeToJDBCTable(jdbcuri, df, table, saveMode)
}
}
package com.gmei
import java.io.Serializable
import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}
import org.apache.spark.sql.{SaveMode,TiContext}
import org.apache.log4j.{Level, Logger}
import scopt.OptionParser
import com.gmei.lib.AbstractParams
import com.soundcloud.lsh.Lsh
object Main {
Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)
case class Params(iter: Int = 10,
lr: Double = 0.025,
numPartition: Int = 10,
dim: Int = 128,
window: Int = 10,
walkLength: Int = 80,
numWalks: Int = 10,
p: Double = 1.0,
q: Double = 1.0,
weighted: Boolean = true,
directed: Boolean = false,
degree: Int = 30,
indexed: Boolean = false,
env: String = ENV.DEV,
nodePath: String = null
) extends AbstractParams[Params] with Serializable
val defaultParams = Params()
val parser = new OptionParser[Params]("Node2Vec_Spark") {
head("Main")
opt[Int]("walkLength")
.text(s"walkLength: ${defaultParams.walkLength}")
.action((x, c) => c.copy(walkLength = x))
opt[Int]("numWalks")
.text(s"numWalks: ${defaultParams.numWalks}")
.action((x, c) => c.copy(numWalks = x))
opt[Double]("p")
.text(s"return parameter p: ${defaultParams.p}")
.action((x, c) => c.copy(p = x))
opt[Double]("q")
.text(s"in-out parameter q: ${defaultParams.q}")
.action((x, c) => c.copy(q = x))
opt[Boolean]("weighted")
.text(s"weighted: ${defaultParams.weighted}")
.action((x, c) => c.copy(weighted = x))
opt[Boolean]("directed")
.text(s"directed: ${defaultParams.directed}")
.action((x, c) => c.copy(directed = x))
opt[Int]("degree")
.text(s"degree: ${defaultParams.degree}")
.action((x, c) => c.copy(degree = x))
opt[Boolean]("indexed")
.text(s"Whether nodes are indexed or not: ${defaultParams.indexed}")
.action((x, c) => c.copy(indexed = x))
opt[String]("env")
.text(s"the databases environment you used")
.action((x,c) => c.copy(env = x))
opt[String]("nodePath")
.text("Input node2index file path: empty")
.action((x, c) => c.copy(nodePath = x))
note(
"""
|For example, the following command runs this app on a synthetic dataset:
|
| bin/spark-submit --class com.nhn.sunny.vegapunk.ml.model.Node2vec \
""".stripMargin +
s"| --lr ${defaultParams.lr}" +
s"| --iter ${defaultParams.iter}" +
s"| --numPartition ${defaultParams.numPartition}" +
s"| --dim ${defaultParams.dim}" +
s"| --window ${defaultParams.window}" +
s"| --node <nodeFilePath>" +
s"| --output <path>"
)
}
def main(args: Array[String]):Unit = {
parser.parse(args, defaultParams).map { param =>
//1. get the input and node2vec
GmeiConfig.setup(param)
val context = GmeiConfig.getSparkSession()._1
val sc = GmeiConfig.getSparkSession()._2
val ti = new TiContext(sc)
ti.tidbMapTable(dbName = GmeiConfig.config.getString("tidb.database"),tableName = "data_meigou_cid")
val tidb_input = sc.sql(
s"""
|SELECT
| service_id,cid
|FROM data_meigou_cid
""".stripMargin
)
Node2vec.setup(context, param)
Node2vec.load(tidb_input)
.initTransitionProb()
.randomWalk()
.embedding()
val node2vector = context.parallelize(Word2vec.getVectors.toList)
.map { case (nodeId, vector) =>
(nodeId.toLong, vector.map(x => x.toDouble))
}
val id2Node = Node2vec.node2id.map{ case (strNode, index) =>
(index, strNode)
}
println("get id2node")
println(id2Node.first())
val node2vec_2 = node2vector.join(id2Node)
.map { case (nodeId, (vector, name)) => (name,vector) }
.repartition(200)
//2. compute similar cid and then take top k
val storageLevel = StorageLevel.MEMORY_AND_DISK
val indexed = node2vec_2.zipWithIndex.persist(storageLevel)
// create indexed row matrix where every row represents one word
val rows = indexed.map {
case ((word, features), index) =>
IndexedRow(index, Vectors.dense(features))
}
// store index for later re-mapping (index to word)
val index = indexed.map {
case ((word, features), index) =>
(index, word)
}.persist(storageLevel)
// create an input matrix from all rows and run lsh on it
val matrix = new IndexedRowMatrix(rows)
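// Approximate all-pairs cosine similarity via the LSH join (com.soundcloud.lsh); pairs below minCosineSimilarity = 0.5 are discarded.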
val lsh = new Lsh(
minCosineSimilarity = 0.5,
dimensions = 20,
numNeighbours = 200,
numPermutations = 10,
partitions = 200,
storageLevel = storageLevel
)
val similarityMatrix = lsh.join(matrix)
import sc.implicits._
// remap both ids back to words
val remapFirst = similarityMatrix.entries.keyBy(_.i).join(index).values
val remapSecond = remapFirst.keyBy { case (entry, word1) => entry.j }.join(index).values.map {
case ((entry, word1), word2) =>
(word1, word2, entry.value)
}
remapSecond.take(20).foreach(println)
val score_result = remapSecond.toDF("cid1","cid2","score")
GmeiConfig.writeToJDBCTable(score_result, table="cid_pairs_cosine_distince", SaveMode.Overwrite)
// group by neighbours to get a list of similar words and then take top k
val result = remapSecond.groupBy(_._1).map {
case (word1, similarWords) =>
// sort by score desc. and take top 10 entries
val similar = similarWords.toSeq.sortBy(-1 * _._3).filter(_._2.startsWith("diary")).take(10).map(_._2).mkString(",")
(word1,s"$similar")
}
// print out the results for the first 10 words
result.take(20).foreach(println)
val similar_result = result.toDF("cid","similarity_cid")
GmeiConfig.writeToJDBCTable(similar_result, table="cid_similarity_matrix", SaveMode.Overwrite)
}
} getOrElse {
sys.exit(1)
}
}
package com.gmei
import java.io.Serializable
import scala.util.Try
import scala.collection.mutable.ArrayBuffer
import org.slf4j.{Logger, LoggerFactory}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.graphx.{EdgeTriplet, Graph, _}
import com.gmei.graph.{EdgeAttr, GraphOps, NodeAttr}
import org.apache.spark.sql.DataFrame
object Node2vec extends Serializable {
lazy val logger: Logger = LoggerFactory.getLogger(getClass.getName);
var context: SparkContext = null
var config: Main.Params = null
var node2id: RDD[(String, Long)] = null
var indexedEdges: RDD[Edge[EdgeAttr]] = _
var indexedNodes: RDD[(VertexId, NodeAttr)] = _
var graph: Graph[NodeAttr, EdgeAttr] = _
var randomWalkPaths: RDD[(Long, ArrayBuffer[Long])] = null
def setup(context: SparkContext, param: Main.Params): this.type = {
this.context = context
this.config = param
this
}
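// Builds adjacency lists from the (service_id, cid) pairs: indexes the nodes, merges neighbours per node, keeps at most `degree` highest-weight neighbours, and derives the edge RDD.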
def load(tidb_input: DataFrame): this.type = {
val bcMaxDegree = context.broadcast(config.degree)
val bcEdgeCreator = config.directed match {
case true => context.broadcast(GraphOps.createDirectedEdge)
case false => context.broadcast(GraphOps.createUndirectedEdge)
}
val inputTriplets: RDD[(Long, Long, Double)] = config.indexed match {
// case true => readIndexedGraph(config.input)
case false => indexingGraph(tidb_input)
}
indexedNodes = inputTriplets.flatMap { case (srcId, dstId, weight) =>
bcEdgeCreator.value.apply(srcId, dstId, weight)
}.reduceByKey(_++_).map { case (nodeId, neighbors: Array[(VertexId, Double)]) =>
var neighbors_ = neighbors
if (neighbors_.length > bcMaxDegree.value) {
neighbors_ = neighbors.sortWith{ case (left, right) => left._2 > right._2 }.slice(0, bcMaxDegree.value)
}
(nodeId, NodeAttr(neighbors = neighbors_.distinct))
}.repartition(200).cache
indexedEdges = indexedNodes.flatMap { case (srcId, clickNode) =>
clickNode.neighbors.map { case (dstId, weight) =>
Edge(srcId, dstId, EdgeAttr())
}
}.repartition(200).cache
this
}
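// Precomputes alias tables: each vertex gets an initial step drawn from its neighbour weights, and each edge gets a (p, q)-biased alias table over the destination's neighbours.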
def initTransitionProb(): this.type = {
val bcP = context.broadcast(config.p)
val bcQ = context.broadcast(config.q)
graph = Graph(indexedNodes, indexedEdges)
.mapVertices[NodeAttr] { case (vertexId, clickNode) =>
val (j, q) = GraphOps.setupAlias(clickNode.neighbors)
val nextNodeIndex = GraphOps.drawAlias(j, q)
clickNode.path = Array(vertexId, clickNode.neighbors(nextNodeIndex)._1)
clickNode
}
.mapTriplets { edgeTriplet: EdgeTriplet[NodeAttr, EdgeAttr] =>
val (j, q) = GraphOps.setupEdgeAlias(bcP.value, bcQ.value)(edgeTriplet.srcId, edgeTriplet.srcAttr.neighbors, edgeTriplet.dstAttr.neighbors)
edgeTriplet.attr.J = j
edgeTriplet.attr.q = q
edgeTriplet.attr.dstNeighbors = edgeTriplet.dstAttr.neighbors.map(_._1)
edgeTriplet.attr
}.cache
this
}
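// Runs numWalks rounds of walks of length walkLength; every step joins the (prev, current) edge key against edge2attr and draws the next node from that edge's alias table.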
def randomWalk(): this.type = {
val edge2attr = graph.triplets.map { edgeTriplet =>
(s"${edgeTriplet.srcId}${edgeTriplet.dstId}", edgeTriplet.attr)
}.repartition(200).cache
edge2attr.first
for (iter <- 0 until config.numWalks) {
var prevWalk: RDD[(Long, ArrayBuffer[Long])] = null
var randomWalk = graph.vertices.map { case (nodeId, clickNode) =>
val pathBuffer = new ArrayBuffer[Long]()
pathBuffer.append(clickNode.path:_*)
(nodeId, pathBuffer)
}.cache
var activeWalks = randomWalk.first
graph.unpersist(blocking = false)
graph.edges.unpersist(blocking = false)
for (walkCount <- 0 until config.walkLength) {
prevWalk = randomWalk
randomWalk = randomWalk.map { case (srcNodeId, pathBuffer) =>
val prevNodeId = pathBuffer(pathBuffer.length - 2)
val currentNodeId = pathBuffer.last
(s"$prevNodeId$currentNodeId", (srcNodeId, pathBuffer))
}.join(edge2attr).map { case (edge, ((srcNodeId, pathBuffer), attr)) =>
try {
val nextNodeIndex = GraphOps.drawAlias(attr.J, attr.q)
val nextNodeId = attr.dstNeighbors(nextNodeIndex)
pathBuffer.append(nextNodeId)
(srcNodeId, pathBuffer)
} catch {
case e: Exception => throw new RuntimeException(e.getMessage)
}
}.cache
activeWalks = randomWalk.first()
prevWalk.unpersist(blocking=false)
}
if (randomWalkPaths != null) {
val prevRandomWalkPaths = randomWalkPaths
randomWalkPaths = randomWalkPaths.union(randomWalk).cache()
randomWalkPaths.first
prevRandomWalkPaths.unpersist(blocking = false)
} else {
randomWalkPaths = randomWalk
}
}
this
}
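// Trains Word2Vec on the generated walks, treating node ids as tokens.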
def embedding(): this.type = {
val randomPaths = randomWalkPaths.map { case (vertexId, pathBuffer) =>
Try(pathBuffer.map(_.toString).toIterable).getOrElse(null)
}.filter(_!=null)
Word2vec.setup(context, config).fit(randomPaths)
this
}
// def save(): this.type = {
// this.saveRandomPath()
// .saveModel()
// .saveVectors()
// }
// def saveRandomPath(): this.type = {
// randomWalkPaths
// .map { case (vertexId, pathBuffer) =>
// Try(pathBuffer.mkString("\t")).getOrElse(null)
// }
// .filter(x => x != null && x.replaceAll("\\s", "").length > 0)
// .repartition(200)
// .saveAsTextFile(config.output)
//
// this
// }
//
// def saveModel(): this.type = {
// Word2vec.save(config.output)
//
// this
// }
//
// def saveVectors(): this.type = {
// val node2vector = context.parallelize(Word2vec.getVectors.toList)
// .map { case (nodeId, vector) =>
// (nodeId.toLong, vector.mkString(","))
// }
//
// if (this.node2id != null) {
// val id2Node = this.node2id.map{ case (strNode, index) =>
// (index, strNode)
// }
//
// node2vector.join(id2Node)
// .map { case (nodeId, (vector, name)) => s"$name\t$vector" }
// .repartition(200)
// .saveAsTextFile(s"${config.output}.emb")
// } else {
// node2vector.map { case (nodeId, vector) => s"$nodeId\t$vector" }
// .repartition(200)
// .saveAsTextFile(s"${config.output}.emb")
// }
//
// this
// }
//
def cleanup(): this.type = {
node2id.unpersist(blocking = false)
indexedEdges.unpersist(blocking = false)
indexedNodes.unpersist(blocking = false)
graph.unpersist(blocking = false)
randomWalkPaths.unpersist(blocking = false)
this
}
def loadNode2Id(node2idPath: String): this.type = {
try {
this.node2id = context.textFile(config.nodePath).map { node2index =>
val Array(strNode, index) = node2index.split("\\s")
(strNode, index.toLong)
}
} catch {
case e: Exception => logger.info("Failed to read node2index file.")
this.node2id = null
}
this
}
// def readIndexedGraph(tripletPath: String) = {
// val bcWeighted = context.broadcast(config.weighted)
//
// val rawTriplets = context.textFile(tripletPath)
// if (config.nodePath == null) {
// this.node2id = createNode2Id(rawTriplets.map { triplet =>
// val parts = triplet.split("\\s")
// (parts.head, parts(1), -1)
// })
// } else {
// loadNode2Id(config.nodePath)
// }
//
// rawTriplets.map { triplet =>
// val parts = triplet.split("\\s")
// val weight = bcWeighted.value match {
// case true => Try(parts.last.toDouble).getOrElse(1.0)
// case false => 1.0
// }
//
// (parts.head.toLong, parts(1).toLong, weight)
// }
// }
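// Converts the (service_id, cid) DataFrame into (srcIndex, dstIndex, weight) triplets via two joins against node2id; the weight falls back to 1.0 when cid is not numeric.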
def indexingGraph(tidb_input: DataFrame): RDD[(Long, Long, Double)] = {
val rawEdges = tidb_input.rdd.map { triplet =>
val parts = (triplet.getAs[String]("service_id"), triplet.getAs[String]("cid"))
Try {
(parts._1, parts._2, Try(parts._2.toDouble).getOrElse(1.0))
}.getOrElse(null)
}.filter(_!=null)
this.node2id = createNode2Id(rawEdges)
rawEdges.map { case (src, dst, weight) =>
(src, (dst, weight))
}.join(node2id).map { case (src, (edge: (String, Double), srcIndex: Long)) =>
try {
val (dst: String, weight: Double) = edge
(dst, (srcIndex, weight))
} catch {
case e: Exception => null
}
}.filter(_!=null).join(node2id).map { case (dst, (edge: (Long, Double), dstIndex: Long)) =>
try {
val (srcIndex, weight) = edge
(srcIndex, dstIndex, weight)
} catch {
case e: Exception => null
}
}.filter(_!=null)
}
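// Assigns a distinct Long index to every node id that appears as a source or destination.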
def createNode2Id[T <: Any](triplets: RDD[(String, String, T)]) = triplets.flatMap { case (src, dst, weight) =>
Try(Array(src, dst)).getOrElse(Array.empty[String])
}.distinct().zipWithIndex()
}
package com.gmei
import org.apache.spark.SparkContext
import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}
import org.apache.spark.rdd.RDD
object Word2vec extends Serializable {
var context: SparkContext = null
var word2vec = new Word2Vec()
var model: Word2VecModel = null
def setup(context: SparkContext, param: Main.Params): this.type = {
this.context = context
/**
* model = sg
* update = hs
*/
word2vec.setLearningRate(param.lr)
.setNumIterations(param.iter)
.setNumPartitions(param.numPartition)
.setMinCount(0)
.setVectorSize(param.dim)
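// Override the default context window size by reflecting on Word2Vec's private `window` field.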
val word2vecWindowField = word2vec.getClass.getDeclaredField("org$apache$spark$mllib$feature$Word2Vec$$window")
word2vecWindowField.setAccessible(true)
word2vecWindowField.setInt(word2vec, param.window)
this
}
def read(path: String): RDD[Iterable[String]] = {
context.textFile(path).repartition(200).map(_.split("\\s").toSeq)
}
def fit(input: RDD[Iterable[String]]): this.type = {
model = word2vec.fit(input)
this
}
def save(outputPath: String): this.type = {
model.save(context, s"$outputPath.bin")
this
}
def load(path: String): this.type = {
model = Word2VecModel.load(context, path)
this
}
def getVectors = this.model.getVectors
}
package com.gmei.graph
import scala.collection.mutable.ArrayBuffer
object GraphOps {
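// Walker's alias method: builds the J/q tables in O(K) so a weighted neighbour can be sampled in O(1).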
def setupAlias(nodeWeights: Array[(Long, Double)]): (Array[Int], Array[Double]) = {
val K = nodeWeights.length
val J = Array.fill(K)(0)
val q = Array.fill(K)(0.0)
val smaller = new ArrayBuffer[Int]()
val larger = new ArrayBuffer[Int]()
val sum = nodeWeights.map(_._2).sum
nodeWeights.zipWithIndex.foreach { case ((nodeId, weight), i) =>
q(i) = K * weight / sum
if (q(i) < 1.0) {
smaller.append(i)
} else {
larger.append(i)
}
}
while (smaller.nonEmpty && larger.nonEmpty) {
val small = smaller.remove(smaller.length - 1)
val large = larger.remove(larger.length - 1)
J(small) = large
q(large) = q(large) + q(small) - 1.0
if (q(large) < 1.0) smaller.append(large)
else larger.append(large)
}
(J, q)
}
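// node2vec second-order bias: weight/p when stepping back to the source, the raw weight when the candidate is also a neighbour of the source, weight/q otherwise.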
def setupEdgeAlias(p: Double = 1.0, q: Double = 1.0)(srcId: Long, srcNeighbors: Array[(Long, Double)], dstNeighbors: Array[(Long, Double)]): (Array[Int], Array[Double]) = {
val neighbors_ = dstNeighbors.map { case (dstNeighborId, weight) =>
var unnormProb = weight / q
if (srcId == dstNeighborId) unnormProb = weight / p
else if (srcNeighbors.exists(_._1 == dstNeighborId)) unnormProb = weight
(dstNeighborId, unnormProb)
}
setupAlias(neighbors_)
}
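// Samples from an alias table: pick a bucket uniformly, accept it with probability q(kk), otherwise take its alias J(kk).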
def drawAlias(J: Array[Int], q: Array[Double]): Int = {
val K = J.length
val kk = math.floor(math.random * K).toInt
if (math.random < q(kk)) kk
else J(kk)
}
lazy val createUndirectedEdge = (srcId: Long, dstId: Long, weight: Double) => {
Array(
(srcId, Array((dstId, weight))),
(dstId, Array((srcId, weight)))
)
}
lazy val createDirectedEdge = (srcId: Long, dstId: Long, weight: Double) => {
Array(
(srcId, Array((dstId, weight)))
)
}
}
package com.gmei
import java.io.Serializable
package object graph {
case class NodeAttr(var neighbors: Array[(Long, Double)] = Array.empty[(Long, Double)],
var path: Array[Long] = Array.empty[Long]) extends Serializable
case class EdgeAttr(var dstNeighbors: Array[Long] = Array.empty[Long],
var J: Array[Int] = Array.empty[Int],
var q: Array[Double] = Array.empty[Double]) extends Serializable
}
package com.gmei.lib
import scala.reflect.runtime.universe._
/**
* Abstract class for parameter case classes.
* This overrides the [[toString]] method to print all case class fields by name and value.
* @tparam T Concrete parameter class.
*/
abstract class AbstractParams[T: TypeTag] {
private def tag: TypeTag[T] = typeTag[T]
/**
* Finds all case class fields in concrete class instance, and outputs them in JSON-style format:
* {
* [field name]:\t[field value]\n
* [field name]:\t[field value]\n
* ...
* }
*/
override def toString: String = {
val tpe = tag.tpe
val allAccessors = tpe.decls.collect {
case m: MethodSymbol if m.isCaseAccessor => m
}
val mirror = runtimeMirror(getClass.getClassLoader)
val instanceMirror = mirror.reflect(this)
allAccessors.map { f =>
val paramName = f.name.toString
val fieldMirror = instanceMirror.reflectField(f)
val paramValue = fieldMirror.get
s" $paramName:\t$paramValue"
}.mkString("{\n", ",\n", "\n}")
}
}
\ No newline at end of file
@@ -8,7 +8,7 @@ def con_sql(sql):
     :type sql : str
     :rtype : tuple
     """
-    db = pymysql.connect(host='10.66.157.22', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    db = pymysql.connect(host='10.66.157.22', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_prod')
     cursor = db.cursor()
     cursor.execute(sql)
     result = cursor.fetchall()