Commit 4bc12115 authored by 高雅喆

add node2vec project

parent 3f7e90a3
*.class
*.log
build.sbt_back
project/
# sbt specific
dist/*
target/
lib_managed/
src_managed/
project/boot/
project/plugins/project/
sbt/*.jar
mini-complete-example/sbt/*.jar
spark-warehouse/
# Scala-IDE specific
.scala_dependencies
#Emacs
*~
#ignore the metastore
metastore_db/*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
.env
.Python
env/bin/
build/*.jar
develop-eggs/
dist/
eggs/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml
# Translations
*.mo
# Mr Developer
.mr.developer.cfg
.project
.pydevproject
# Rope
.ropeproject
# Django stuff:
*.log
*.pot
# Sphinx documentation
docs/_build/
# PyCharm files
*.idea
# emacs stuff
# Autoenv
.env
*~
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
.env
.Python
env/
bin/
build/
develop-eggs/
dist/
eggs/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml
# Translations
*.mo
# Mr Developer
.mr.developer.cfg
.project
.pydevproject
# Rope
.ropeproject
# Django stuff:
*.log
*.pot
# Sphinx documentation
docs/_build/
# PyCharm files
*.idea
# emacs stuff
\#*\#
\.\#*
# Autoenv
.env
*~
Copyright 2014 Typesafe, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
## spark-hadoop jar
http://mirror.bit.edu.cn/apache/spark/spark-2.2.1/spark-2.2.1-bin-hadoop2.7.tgz
## tispark jar (driver)
http://download.pingcap.org/tispark-1.0-RC1-jar-with-dependencies.jar
## jdk >= 1.8.0
http://www.oracle.com/technetwork/java/javase/downloads/jdk8-downloads-2133151.html
## mysql driver
http://central.maven.org/maven2/mysql/mysql-connector-java/5.1.44/mysql-connector-java-5.1.44.jar
## jar placement
Put the mysql jar and the tispark jar under:
/usr/local/Cellar/apache-spark/2.1.2/jars
## spark configuration
/usr/local/Cellar/apache-spark/2.1.2/conf
spark-env.sh
export SPARK_EXECUTOR_MEMORY=3g
export SPARK_WORKER_MEMORY=3g
export SPARK_WORKER_CORES=2
spark-defaults.conf
spark.tispark.pd.addresses 192.168.15.11:2379
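## submitting the job (example)
A minimal local run, assuming the fat jar was built with `sbt assembly`; the jar file name below is illustrative and may differ:
spark-submit --class com.gmei.Main Node2vec-assembly-0.2.jar --env dev --numWalks 10 --walkLength 80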
name := """Node2vec"""
lazy val commonSettings = Seq(
version := "0.2",
organization := "com.gmei",
scalaVersion := "2.11.8",
test in assembly := {}
)
autoScalaLibrary := false
val sparkVersion = "2.2.1"
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-core" % sparkVersion,
"org.apache.spark" %% "spark-sql" % sparkVersion,
"org.apache.spark" %% "spark-hive" % sparkVersion,
"org.apache.spark" %% "spark-streaming" % sparkVersion,
"org.apache.spark" %% "spark-streaming-kafka-0-10" % sparkVersion,
"org.apache.spark" %% "spark-mllib" % sparkVersion,
"mysql" % "mysql-connector-java" % "5.1.38",
"com.typesafe" % "config" % "1.3.2",
"org.apache.logging.log4j" % "log4j-scala" % "11.0" pomOnly(),
"org.scalatest" %% "scalatest" % "3.0.5" % "test",
"com.github.nscala-time" %% "nscala-time" % "2.18.0",
"com.github.scopt" %% "scopt" % "3.7.0",
"com.google.guava" % "guava" % "19.0",
"com.github.fommil.netlib" % "all" % "1.1.2"
)
lazy val root = (project in file(".")).settings(commonSettings: _*)
assemblyMergeStrategy in assembly := {
case PathList("META-INF", xs @ _*) => MergeStrategy.discard
case x => MergeStrategy.first
}
#### source table naming rules
src_$dbname_$tablename, e.g.
src_mimas_prod_api_preoperationimage
#### etl table naming rules
mid_$tablename, e.g.
mid_diary_tags
#### theme table naming rules
biz_$tablename, e.g.
biz_user_index_diary_list_ctr
dev.tidb.jdbcuri=jdbc:mysql://192.168.15.12:4000/jerry_test?user=root&password=&rewriteBatchedStatements=true
dev.tispark.pd.addresses=192.168.15.11:2379
dev.mimas.jdbcuri= jdbc:mysql://rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com/mimas_test?user=work&password=workwork&rewriteBatchedStatements=true
dev.gaia.jdbcuri=jdbc:mysql://rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com/zhengxing_test?user=work&password=workwork&rewriteBatchedStatements=true
dev.gold.jdbcuri=jdbc:mysql://rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com/doris_test?user=work&password=workwork&rewriteBatchedStatements=true
pre.tidb.jdbcuri=jdbc:mysql://192.168.16.11:4000/eagle?user=root&password=&rewriteBatchedStatements=true
pre.tispark.pd.addresses=192.168.16.11:2379
pre.mimas.jdbcuri=jdbc:mysql://rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com:3308/mimas_prod?user=mimas&password=workwork&rewriteBatchedStatements=true
prod.tidb.jdbcuri=jdbc:mysql://10.66.157.22:4000/jerry_test?user=root&password=3SYz54LS9#^9sBvC&rewriteBatchedStatements=true
prod.gold.jdbcuri=jdbc:mysql://rm-m5e842126ng59jrv6.mysql.rds.aliyuncs.com/doris_prod?user=doris&password=o5gbA27hXHHm&rewriteBatchedStatements=true
prod.mimas.jdbcuri=jdbc:mysql://rm-m5emg41za2w7l6au3.mysql.rds.aliyuncs.com/mimas_prod?user=mimas&password=GJL3UJe1Ck9ggL6aKnZCq4cRvM&rewriteBatchedStatements=true
prod.gaia.jdbcuri=jdbc:mysql://rdsfewzdmf0jfjp9un8xj.mysql.rds.aliyuncs.com/zhengxing?user=work&password=BJQaT9VzDcuPBqkd&rewriteBatchedStatements=true
prod.tispark.pd.addresses=10.66.157.22:2379
appender.out.type = Console
appender.out.name = out
appender.out.layout.type = PatternLayout
appender.out.layout.pattern = [%30.30t] %-30.30c{1} %-5p %m%n
logger.springframework.name = org.springframework
logger.springframework.level = WARN
rootLogger.level = INFO
rootLogger.appenderRef.out.ref = out
package com.gmei
import java.util.Arrays
import scala.util.hashing.MurmurHash3
/**
* A simple, fixed-size bit set implementation. This implementation is fast because it avoids
* safety/bound checking.
*/
class BitSet(val numBits: Int) extends Serializable {
private val words = new Array[Long](bit2words(numBits))
private val numWords = words.length
/**
* Compute the capacity (number of bits) that can be represented
* by this bitset.
*/
def capacity: Int = numWords * 64
/**
* Clear all set bits.
*/
def clear(): Unit = Arrays.fill(words, 0)
/**
* Set all the bits up to a given index
*/
def setUntil(bitIndex: Int): Unit = {
val wordIndex = bitIndex >> 6 // divide by 64
Arrays.fill(words, 0, wordIndex, -1)
if(wordIndex < words.length) {
// Set the remaining bits (note that the mask could still be zero)
val mask = ~(-1L << (bitIndex & 0x3f))
words(wordIndex) |= mask
}
}
/**
* Clear all the bits up to a given index
*/
def clearUntil(bitIndex: Int): Unit = {
val wordIndex = bitIndex >> 6 // divide by 64
Arrays.fill(words, 0, wordIndex, 0)
if(wordIndex < words.length) {
// Clear the remaining bits
val mask = -1L << (bitIndex & 0x3f)
words(wordIndex) &= mask
}
}
/**
* Compute the bit-wise AND of the two sets returning the
* result.
*/
def &(other: BitSet): BitSet = {
val newBS = new BitSet(math.max(capacity, other.capacity))
val smaller = math.min(numWords, other.numWords)
assert(newBS.numWords >= numWords)
assert(newBS.numWords >= other.numWords)
var ind = 0
while( ind < smaller ) {
newBS.words(ind) = words(ind) & other.words(ind)
ind += 1
}
newBS
}
/**
* Compute the bit-wise OR of the two sets returning the
* result.
*/
def |(other: BitSet): BitSet = {
val newBS = new BitSet(math.max(capacity, other.capacity))
assert(newBS.numWords >= numWords)
assert(newBS.numWords >= other.numWords)
val smaller = math.min(numWords, other.numWords)
var ind = 0
while( ind < smaller ) {
newBS.words(ind) = words(ind) | other.words(ind)
ind += 1
}
while( ind < numWords ) {
newBS.words(ind) = words(ind)
ind += 1
}
while( ind < other.numWords ) {
newBS.words(ind) = other.words(ind)
ind += 1
}
newBS
}
/**
* Compute the symmetric difference by performing bit-wise XOR of the two sets returning the
* result.
*/
def ^(other: BitSet): BitSet = {
val newBS = new BitSet(math.max(capacity, other.capacity))
val smaller = math.min(numWords, other.numWords)
var ind = 0
while (ind < smaller) {
newBS.words(ind) = words(ind) ^ other.words(ind)
ind += 1
}
if (ind < numWords) {
Array.copy( words, ind, newBS.words, ind, numWords - ind )
}
if (ind < other.numWords) {
Array.copy( other.words, ind, newBS.words, ind, other.numWords - ind )
}
newBS
}
/**
* Compute the difference of the two sets by performing bit-wise AND-NOT returning the
* result.
*/
def andNot(other: BitSet): BitSet = {
val newBS = new BitSet(capacity)
val smaller = math.min(numWords, other.numWords)
var ind = 0
while (ind < smaller) {
newBS.words(ind) = words(ind) & ~other.words(ind)
ind += 1
}
if (ind < numWords) {
Array.copy( words, ind, newBS.words, ind, numWords - ind )
}
newBS
}
/**
* Sets the bit at the specified index to true.
* @param index the bit index
*/
def set(index: Int) {
val bitmask = 1L << (index & 0x3f) // mod 64 and shift
words(index >> 6) |= bitmask // div by 64 and mask
}
def unset(index: Int) {
val bitmask = 1L << (index & 0x3f) // mod 64 and shift
words(index >> 6) &= ~bitmask // div by 64 and mask
}
def flip(index: Int) {
val bitmask = 1L << (index & 0x3f)
words(index >> 6) ^= bitmask
}
/**
* Return the value of the bit with the specified index. The value is true if the bit with
* the index is currently set in this BitSet; otherwise, the result is false.
*
* @param index the bit index
* @return the value of the bit with the specified index
*/
def get(index: Int): Boolean = {
val bitmask = 1L << (index & 0x3f) // mod 64 and shift
(words(index >> 6) & bitmask) != 0 // div by 64 and mask
}
/**
* Get an iterator over the set bits.
*/
def iterator: Iterator[Int] = new Iterator[Int] {
var ind = nextSetBit(0)
override def hasNext: Boolean = ind >= 0
override def next(): Int = {
val tmp = ind
ind = nextSetBit(ind + 1)
tmp
}
}
/** Return the number of bits set to true in this BitSet. */
def cardinality(): Int = {
var sum = 0
var i = 0
while (i < numWords) {
sum += java.lang.Long.bitCount(words(i))
i += 1
}
sum
}
/**
* Returns the index of the first bit that is set to true that occurs on or after the
* specified starting index. If no such bit exists then -1 is returned.
*
* To iterate over the true bits in a BitSet, use the following loop:
*
* for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i+1)) {
* // operate on index i here
* }
*
* @param fromIndex the index to start checking from (inclusive)
* @return the index of the next set bit, or -1 if there is no such bit
*/
def nextSetBit(fromIndex: Int): Int = {
var wordIndex = fromIndex >> 6
if (wordIndex >= numWords) {
return -1
}
// Try to find the next set bit in the current word
val subIndex = fromIndex & 0x3f
var word = words(wordIndex) >> subIndex
if (word != 0) {
return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word)
}
// Find the next set bit in the rest of the words
wordIndex += 1
while (wordIndex < numWords) {
word = words(wordIndex)
if (word != 0) {
return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word)
}
wordIndex += 1
}
-1
}
/** Return the number of longs it would take to hold numBits. */
private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1
/** Override hashCode method to allow BitSet be used as an RDD key */
override def hashCode() = MurmurHash3.arrayHash(words)
}
object BitSet {
def apply(bitSet: BitSet) = {
val bitSetCopy = new BitSet(bitSet.numBits)
bitSet.iterator.foreach(ix => bitSetCopy.set(ix))
bitSetCopy
}
}
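// Usage sketch (illustrative, not part of the original file) of the operations defined above:
//   val a = new BitSet(128); a.set(3); a.set(64)
//   val b = new BitSet(128); b.set(64)
//   (a & b).iterator.toList   // List(64)
//   (a | b).cardinality()     // 2
//   a.nextSetBit(4)           // 64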
package com.gmei
object ENV extends Enumeration {
type ENV = String
val PROD = "prod"
val DEV = "dev"
val PRE = "pre"
}
package com.gmei
import java.util.Properties
import java.io.Serializable
import com.typesafe.config._
import org.apache.spark.{SparkConf,SparkContext}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession, TiContext}
import com.gmei.ENV.ENV
object GmeiConfig extends Serializable {
var param1: Main.Params = null
var env: String = null
var config: Config = null
def setup(param: Main.Params): this.type = {
this.param1 = param
this.env = this.param1.env match {
case "prod" => ENV.PROD
case "dev" => ENV.DEV
case "pre" => ENV.PRE
case _ => ENV.DEV
}
this.config = initConfig(this.env)
this
}
//If evaluated here, param1 would still be null, so a NullPointerException would be thrown
// val env = param1.env match {
// case "prod" => ENV.PROD
// case "dev" => ENV.DEV
// case "pre" => ENV.PRE
// case _ => ENV.DEV
// }
// val config = initConfig(env)
def initConfig(env: ENV) = {
lazy val c = ConfigFactory.load()
c.getConfig(env).withFallback(c)
}
def getSparkSession():(SparkContext, SparkSession) = {
val sparkConf = new SparkConf
sparkConf.set("spark.sql.crossJoin.enabled", "true")
if (!sparkConf.contains("spark.master")) {
sparkConf.setMaster("local[3]")
}
if (!sparkConf.contains("spark.tispark.pd.addresses")) {
sparkConf.set("spark.tispark.pd.addresses", config.getString("tispark.pd.addresses"))
}
println(sparkConf.get("spark.tispark.pd.addresses"))
val spark = SparkSession
.builder()
.config(sparkConf)
.appName("node2vec")
.getOrCreate()
val context = SparkContext.getOrCreate(sparkConf)
val ti = new TiContext(spark)
// mapping all tables cache it.
ti.tidbMapDatabase("jerry_test")
(context, spark)
}
def writeToJDBCTable(jdbcuri: String, df: DataFrame, table: String, saveModel: SaveMode): Unit = {
println(jdbcuri, table)
val prop = new Properties()
prop.put("driver", "com.mysql.jdbc.Driver")
prop.put("useSSL", "false")
prop.put("isolationLevel", "NONE")
prop.put("isTruncate", "true")
// save to mysql/tidb
df.repartition(128).write.mode(saveModel)
.option(JDBCOptions.JDBC_BATCH_INSERT_SIZE, 300)
.jdbc(jdbcuri, table, prop)
}
def writeToJDBCTable(df: DataFrame, table: String, saveMode: SaveMode): Unit = {
val jdbcuri = config.getString("tidb.jdbcuri")
println(jdbcuri, table)
writeToJDBCTable(jdbcuri, df, table, saveMode)
}
}
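// Usage sketch (illustrative, mirroring the driver in Main.scala): the configuration is
// initialised once, then the SparkContext/SparkSession pair and the JDBC writer are reused.
//   GmeiConfig.setup(param)                                  // param: Main.Params
//   val (context, spark) = GmeiConfig.getSparkSession()
//   GmeiConfig.writeToJDBCTable(someDf, table = "some_table", SaveMode.Overwrite)  // "some_table" is a placeholder name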
package com.gmei
import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, IndexedRowMatrix}
trait Joiner {
/**
* Find the k nearest neighbors from a data set for every other object in the
* same data set. Implementations may be either exact or approximate.
*
* @param matrix a row oriented matrix. Each row in the matrix represents
* an item in the data set. Items are identified by their
* matrix index.
* @return a similarity matrix with MatrixEntry(itemA, itemB, similarity).
*
*/
def join(matrix: IndexedRowMatrix): CoordinateMatrix
}
trait QueryJoiner {
/**
* Find the k nearest neighbours in catalogMatrix for each entry in queryMatrix.
* Implementations may be either exact or approximate.
*
* @param queryMatrix a row oriented matrix. Each row in the matrix represents
* an item in the data set. Items are identified by their
* matrix index.
* @param catalogMatrix a row oriented matrix. Each row in the matrix represents
* an item in the data set. Items are identified by their
* matrix index.
* @return a similarity matrix with MatrixEntry(queryA, catalogB, similarity).
*/
def join(queryMatrix: IndexedRowMatrix, catalogMatrix: IndexedRowMatrix): CoordinateMatrix
}
package com.gmei
import org.apache.spark.mllib.linalg.distributed._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import com.gmei.lsh.{localRandomMatrix,matrixToBitSet,distinct,Signature}
/**
* Lsh implementation as described in 'Randomized Algorithms and NLP: Using
* Locality Sensitive Hash Function for High Speed Noun Clustering' by
* Ravichandran et al. See original publication for a detailed description of
* the parameters.
*
* @see http://dl.acm.org/citation.cfm?id=1219917
* @param minCosineSimilarity minimum similarity two items need to have
* otherwise they are discarded from the result set
* @param dimensions number of random vectors (hyperplanes) to generate bit
* vectors of length d
* @param numNeighbours beam factor, i.e. how many neighbours are considered
* in the sliding window
* @param numPermutations number of times bitsets are permuted
*
*/
class Lsh(minCosineSimilarity: Double,
dimensions: Int,
numNeighbours: Int,
numPermutations: Int,
partitions: Int = 200,
storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK)
extends Joiner with Serializable {
override def join(inputMatrix: IndexedRowMatrix): CoordinateMatrix = {
val numFeatures = inputMatrix.numCols().toInt
val randomMatrix = localRandomMatrix(dimensions, numFeatures)
val signatures = matrixToBitSet(inputMatrix, randomMatrix).
repartition(partitions).
persist(storageLevel)
val neighbours = 0 until numPermutations map {
x =>
val permutation = generatePermutation(dimensions)
val permuted = permuteBitSet(signatures, permutation, dimensions)
val orderedSignatures = orderByBitSet(permuted)
val slidingWindow = createSlidingWindow(orderedSignatures, numNeighbours)
findNeighbours(slidingWindow, minCosineSimilarity)
}
val mergedNeighbours = neighbours.reduce(_ ++ _)
new CoordinateMatrix(distinct(mergedNeighbours))
}
/**
* Permutes the signatures by a given permutation
*
*/
def permuteBitSet(signatures: RDD[Signature], permutation: Iterable[Int], d: Int): RDD[Signature] = {
signatures.map {
signature: Signature =>
val permutedBitSet = permuteBitSet(signature.bitSet, permutation, d)
signature.copy(bitSet = permutedBitSet)
}
}
/**
* Permutes a bit set representation of a vector by a given permutation
*/
def permuteBitSet(bitSet: BitSet, permutation: Iterable[Int], d: Int): BitSet = {
val permutationWithIndex = permutation.zipWithIndex
val newBitSet = new BitSet(d)
permutationWithIndex.foreach {
case ((newIndex: Int, oldIndex: Int)) =>
val oldBit = bitSet.get(oldIndex)
if (oldBit)
newBitSet.set(newIndex)
else
newBitSet.unset(newIndex)
}
newBitSet
}
/**
* Generates a random permutation of size n
*/
def generatePermutation(size: Int): Iterable[Int] = {
val indices = (0 until size).toArray
util.Random.shuffle(indices)
}
/**
* Orders an RDD of signatures by their bit set representation
*/
def orderByBitSet(signatures: RDD[Signature]): RDD[Signature] = {
signatures.sortBy(identity)
}
/**
* Creates a sliding window
*
*/
def createSlidingWindow(signatures: RDD[Signature], b: Int): RDD[Array[Signature]] = {
new SlidingRDD[Signature](signatures, b, b)
}
def findNeighbours(signatures: RDD[Array[Signature]], minCosineSimilarity: Double): RDD[MatrixEntry] = {
signatures.flatMap { signature: Array[Signature] =>
neighbours(signature, minCosineSimilarity)
}
}
/**
* Generate all pairs and emit if cosine of pair > minCosineSimilarity
*
*/
def neighbours(signatures: Array[Signature], minCosineSimilarity: Double): Iterator[MatrixEntry] = {
signatures.
sortBy(_.index). // sort in order to create an upper triangular matrix
combinations(2).
map {
case Array(first, second) =>
val cosine = Cosine(first.vector, second.vector)
MatrixEntry(first.index, second.index, cosine)
}.
filter(_.value >= minCosineSimilarity)
}
}
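// Usage sketch (illustrative): joining an IndexedRowMatrix with the parameter values that
// Main.scala uses; the result is a CoordinateMatrix of MatrixEntry(i, j, cosine).
//   val lsh = new Lsh(minCosineSimilarity = 0.5, dimensions = 20, numNeighbours = 200,
//                     numPermutations = 10, partitions = 200)
//   val similarities = lsh.join(matrix)   // matrix: IndexedRowMatrix of item vectors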
package com.gmei
import java.io.Serializable
import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}
import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{SaveMode,TiContext}
import org.apache.log4j.{Level, Logger}
import scopt.OptionParser
import com.gmei.lib.AbstractParams
object Main {
Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)
case class Params(iter: Int = 10,
lr: Double = 0.025,
numPartition: Int = 10,
dim: Int = 128,
window: Int = 10,
walkLength: Int = 80,
numWalks: Int = 10,
p: Double = 1.0,
q: Double = 1.0,
weighted: Boolean = true,
directed: Boolean = false,
degree: Int = 30,
indexed: Boolean = false,
env: String = ENV.DEV,
nodePath: String = null
) extends AbstractParams[Params] with Serializable
val defaultParams = Params()
val parser = new OptionParser[Params]("Node2Vec_Spark") {
head("Main")
opt[Int]("walkLength")
.text(s"walkLength: ${defaultParams.walkLength}")
.action((x, c) => c.copy(walkLength = x))
opt[Int]("numWalks")
.text(s"numWalks: ${defaultParams.numWalks}")
.action((x, c) => c.copy(numWalks = x))
opt[Double]("p")
.text(s"return parameter p: ${defaultParams.p}")
.action((x, c) => c.copy(p = x))
opt[Double]("q")
.text(s"in-out parameter q: ${defaultParams.q}")
.action((x, c) => c.copy(q = x))
opt[Boolean]("weighted")
.text(s"weighted: ${defaultParams.weighted}")
.action((x, c) => c.copy(weighted = x))
opt[Boolean]("directed")
.text(s"directed: ${defaultParams.directed}")
.action((x, c) => c.copy(directed = x))
opt[Int]("degree")
.text(s"degree: ${defaultParams.degree}")
.action((x, c) => c.copy(degree = x))
opt[Boolean]("indexed")
.text(s"Whether nodes are indexed or not: ${defaultParams.indexed}")
.action((x, c) => c.copy(indexed = x))
opt[String]("env")
.text(s"the databases environment you used")
.action((x,c) => c.copy(env = x))
opt[String]("nodePath")
.text("Input node2index file path: empty")
.action((x, c) => c.copy(nodePath = x))
note(
"""
|For example, the following command runs this app on a synthetic dataset:
|
| bin/spark-submit --class com.gmei.Main \
""".stripMargin +
s"| --lr ${defaultParams.lr}" +
s"| --iter ${defaultParams.iter}" +
s"| --numPartition ${defaultParams.numPartition}" +
s"| --dim ${defaultParams.dim}" +
s"| --window ${defaultParams.window}" +
s"| --node <nodeFilePath>" +
s"| --output <path>"
)
}
def main(args: Array[String]):Unit = {
parser.parse(args, defaultParams).map { param =>
//1. get the input and node2vec
GmeiConfig.setup(param)
val (context, sc) = GmeiConfig.getSparkSession()
val ti = new TiContext(sc)
ti.tidbMapTable(dbName = "jerry_test",tableName = "data_diary_click")
val tidb_input = sc.sql(
s"""
|SELECT
| service_id,cid
|FROM data_diary_click
""".stripMargin
)
Node2vec.setup(context, param)
Node2vec.load(tidb_input)
.initTransitionProb()
.randomWalk()
.embedding()
val node2vector = context.parallelize(Word2vec.getVectors.toList)
.map { case (nodeId, vector) =>
(nodeId.toLong, vector.map(x => x.toDouble))
}
val id2Node = Node2vec.node2id.map{ case (strNode, index) =>
(index, strNode)
}
println("get id2node")
println(id2Node.first())
val node2vec_2 = node2vector.join(id2Node)
.map { case (nodeId, (vector, name)) => (name,vector) }
.repartition(200)
//2. compute similar cid and then take top k
val storageLevel = StorageLevel.MEMORY_AND_DISK
val indexed = node2vec_2.zipWithIndex.persist(storageLevel)
// create indexed row matrix where every row represents one word
val rows = indexed.map {
case ((word, features), index) =>
IndexedRow(index, Vectors.dense(features))
}
// store index for later re-mapping (index to word)
val index = indexed.map {
case ((word, features), index) =>
(index, word)
}.persist(storageLevel)
// create an input matrix from all rows and run lsh on it
val matrix = new IndexedRowMatrix(rows)
val lsh = new Lsh(
minCosineSimilarity = 0.5,
dimensions = 20,
numNeighbours = 200,
numPermutations = 10,
partitions = 200,
storageLevel = storageLevel
)
val similarityMatrix = lsh.join(matrix)
import sc.implicits._
// remap both ids back to words
val remapFirst = similarityMatrix.entries.keyBy(_.i).join(index).values
val remapSecond = remapFirst.keyBy { case (entry, word1) => entry.j }.join(index).values.map {
case ((entry, word1), word2) =>
(word1, word2, entry.value)
}
remapSecond.take(20).foreach(println)
val w1 = Window.orderBy($"score")
val score_result = remapSecond.toDF("cid1","cid2","score").withColumn("id",row_number().over(w1))
GmeiConfig.writeToJDBCTable(score_result, table="cid_pairs_cosine_distince", SaveMode.Overwrite)
// group by neighbours to get a list of similar words and then take top k
val result = remapSecond.groupBy(_._1).map {
case (word1, similarWords) =>
// sort by score desc. and take top 10 entries
val similar = similarWords.toSeq.sortBy(-1 * _._3).take(10).map(_._2).mkString(",")
(word1,s"$similar")
}
// print out the results for the first 10 words
result.take(20).foreach(println)
val w2 = Window.orderBy($"cid")
val similar_result = result.toDF("cid","similarity_cid").withColumn("id",row_number().over(w2))
GmeiConfig.writeToJDBCTable(similar_result, table="cid_similarity_matrix", SaveMode.Overwrite)
}
} getOrElse {
sys.exit(1)
}
}
package com.gmei
import java.io.Serializable
import scala.util.Try
import scala.collection.mutable.ArrayBuffer
import org.slf4j.{Logger, LoggerFactory}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.graphx.{EdgeTriplet, Graph, _}
import com.gmei.graph.{EdgeAttr, GraphOps, NodeAttr}
import org.apache.spark.sql.DataFrame
object Node2vec extends Serializable {
lazy val logger: Logger = LoggerFactory.getLogger(getClass.getName);
var context: SparkContext = null
var config: Main.Params = null
var node2id: RDD[(String, Long)] = null
var indexedEdges: RDD[Edge[EdgeAttr]] = _
var indexedNodes: RDD[(VertexId, NodeAttr)] = _
var graph: Graph[NodeAttr, EdgeAttr] = _
var randomWalkPaths: RDD[(Long, ArrayBuffer[Long])] = null
def setup(context: SparkContext, param: Main.Params): this.type = {
this.context = context
this.config = param
this
}
def load(tidb_input: DataFrame): this.type = {
val bcMaxDegree = context.broadcast(config.degree)
val bcEdgeCreator = config.directed match {
case true => context.broadcast(GraphOps.createDirectedEdge)
case false => context.broadcast(GraphOps.createUndirectedEdge)
}
val inputTriplets: RDD[(Long, Long, Double)] = config.indexed match {
// case true => readIndexedGraph(config.input)
case false => indexingGraph(tidb_input)
}
indexedNodes = inputTriplets.flatMap { case (srcId, dstId, weight) =>
bcEdgeCreator.value.apply(srcId, dstId, weight)
}.reduceByKey(_++_).map { case (nodeId, neighbors: Array[(VertexId, Double)]) =>
var neighbors_ = neighbors
if (neighbors_.length > bcMaxDegree.value) {
neighbors_ = neighbors.sortWith{ case (left, right) => left._2 > right._2 }.slice(0, bcMaxDegree.value)
}
(nodeId, NodeAttr(neighbors = neighbors_.distinct))
}.repartition(200).cache
indexedEdges = indexedNodes.flatMap { case (srcId, clickNode) =>
clickNode.neighbors.map { case (dstId, weight) =>
Edge(srcId, dstId, EdgeAttr())
}
}.repartition(200).cache
this
}
def initTransitionProb(): this.type = {
val bcP = context.broadcast(config.p)
val bcQ = context.broadcast(config.q)
graph = Graph(indexedNodes, indexedEdges)
.mapVertices[NodeAttr] { case (vertexId, clickNode) =>
val (j, q) = GraphOps.setupAlias(clickNode.neighbors)
val nextNodeIndex = GraphOps.drawAlias(j, q)
clickNode.path = Array(vertexId, clickNode.neighbors(nextNodeIndex)._1)
clickNode
}
.mapTriplets { edgeTriplet: EdgeTriplet[NodeAttr, EdgeAttr] =>
val (j, q) = GraphOps.setupEdgeAlias(bcP.value, bcQ.value)(edgeTriplet.srcId, edgeTriplet.srcAttr.neighbors, edgeTriplet.dstAttr.neighbors)
edgeTriplet.attr.J = j
edgeTriplet.attr.q = q
edgeTriplet.attr.dstNeighbors = edgeTriplet.dstAttr.neighbors.map(_._1)
edgeTriplet.attr
}.cache
this
}
def randomWalk(): this.type = {
val edge2attr = graph.triplets.map { edgeTriplet =>
(s"${edgeTriplet.srcId}${edgeTriplet.dstId}", edgeTriplet.attr)
}.repartition(200).cache
edge2attr.first
for (iter <- 0 until config.numWalks) {
var prevWalk: RDD[(Long, ArrayBuffer[Long])] = null
var randomWalk = graph.vertices.map { case (nodeId, clickNode) =>
val pathBuffer = new ArrayBuffer[Long]()
pathBuffer.append(clickNode.path:_*)
(nodeId, pathBuffer)
}.cache
var activeWalks = randomWalk.first
graph.unpersist(blocking = false)
graph.edges.unpersist(blocking = false)
for (walkCount <- 0 until config.walkLength) {
prevWalk = randomWalk
randomWalk = randomWalk.map { case (srcNodeId, pathBuffer) =>
val prevNodeId = pathBuffer(pathBuffer.length - 2)
val currentNodeId = pathBuffer.last
(s"$prevNodeId$currentNodeId", (srcNodeId, pathBuffer))
}.join(edge2attr).map { case (edge, ((srcNodeId, pathBuffer), attr)) =>
try {
val nextNodeIndex = GraphOps.drawAlias(attr.J, attr.q)
val nextNodeId = attr.dstNeighbors(nextNodeIndex)
pathBuffer.append(nextNodeId)
(srcNodeId, pathBuffer)
} catch {
case e: Exception => throw new RuntimeException(e.getMessage)
}
}.cache
activeWalks = randomWalk.first()
prevWalk.unpersist(blocking=false)
}
if (randomWalkPaths != null) {
val prevRandomWalkPaths = randomWalkPaths
randomWalkPaths = randomWalkPaths.union(randomWalk).cache()
randomWalkPaths.first
prevRandomWalkPaths.unpersist(blocking = false)
} else {
randomWalkPaths = randomWalk
}
}
this
}
def embedding(): this.type = {
val randomPaths = randomWalkPaths.map { case (vertexId, pathBuffer) =>
Try(pathBuffer.map(_.toString).toIterable).getOrElse(null)
}.filter(_!=null)
Word2vec.setup(context, config).fit(randomPaths)
this
}
// def save(): this.type = {
// this.saveRandomPath()
// .saveModel()
// .saveVectors()
// }
// def saveRandomPath(): this.type = {
// randomWalkPaths
// .map { case (vertexId, pathBuffer) =>
// Try(pathBuffer.mkString("\t")).getOrElse(null)
// }
// .filter(x => x != null && x.replaceAll("\\s", "").length > 0)
// .repartition(200)
// .saveAsTextFile(config.output)
//
// this
// }
//
// def saveModel(): this.type = {
// Word2vec.save(config.output)
//
// this
// }
//
// def saveVectors(): this.type = {
// val node2vector = context.parallelize(Word2vec.getVectors.toList)
// .map { case (nodeId, vector) =>
// (nodeId.toLong, vector.mkString(","))
// }
//
// if (this.node2id != null) {
// val id2Node = this.node2id.map{ case (strNode, index) =>
// (index, strNode)
// }
//
// node2vector.join(id2Node)
// .map { case (nodeId, (vector, name)) => s"$name\t$vector" }
// .repartition(200)
// .saveAsTextFile(s"${config.output}.emb")
// } else {
// node2vector.map { case (nodeId, vector) => s"$nodeId\t$vector" }
// .repartition(200)
// .saveAsTextFile(s"${config.output}.emb")
// }
//
// this
// }
//
def cleanup(): this.type = {
node2id.unpersist(blocking = false)
indexedEdges.unpersist(blocking = false)
indexedNodes.unpersist(blocking = false)
graph.unpersist(blocking = false)
randomWalkPaths.unpersist(blocking = false)
this
}
def loadNode2Id(node2idPath: String): this.type = {
try {
this.node2id = context.textFile(node2idPath).map { node2index =>
val Array(strNode, index) = node2index.split("\\s")
(strNode, index.toLong)
}
} catch {
case e: Exception => logger.info("Failed to read node2index file.")
this.node2id = null
}
this
}
// def readIndexedGraph(tripletPath: String) = {
// val bcWeighted = context.broadcast(config.weighted)
//
// val rawTriplets = context.textFile(tripletPath)
// if (config.nodePath == null) {
// this.node2id = createNode2Id(rawTriplets.map { triplet =>
// val parts = triplet.split("\\s")
// (parts.head, parts(1), -1)
// })
// } else {
// loadNode2Id(config.nodePath)
// }
//
// rawTriplets.map { triplet =>
// val parts = triplet.split("\\s")
// val weight = bcWeighted.value match {
// case true => Try(parts.last.toDouble).getOrElse(1.0)
// case false => 1.0
// }
//
// (parts.head.toLong, parts(1).toLong, weight)
// }
// }
def indexingGraph(tidb_input: DataFrame): RDD[(Long, Long, Double)] = {
val rawEdges = tidb_input.rdd.map { triplet =>
val parts = (triplet.getAs[String]("service_id"), triplet.getAs[String]("cid"))
Try {
(parts._1, parts._2, Try(parts._2.toDouble).getOrElse(1.0))
}.getOrElse(null)
}.filter(_!=null)
this.node2id = createNode2Id(rawEdges)
rawEdges.map { case (src, dst, weight) =>
(src, (dst, weight))
}.join(node2id).map { case (src, (edge: (String, Double), srcIndex: Long)) =>
try {
val (dst: String, weight: Double) = edge
(dst, (srcIndex, weight))
} catch {
case e: Exception => null
}
}.filter(_!=null).join(node2id).map { case (dst, (edge: (Long, Double), dstIndex: Long)) =>
try {
val (srcIndex, weight) = edge
(srcIndex, dstIndex, weight)
} catch {
case e: Exception => null
}
}.filter(_!=null)
}
def createNode2Id[T <: Any](triplets: RDD[(String, String, T)]) = triplets.flatMap { case (src, dst, weight) =>
Try(Array(src, dst)).getOrElse(Array.empty[String])
}.distinct().zipWithIndex()
}
package com.gmei
import scala.collection.mutable
import scala.reflect.ClassTag
import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
/**
* NOTE: both classes are copied from mllib and slightly modified since these classes are mllib private!
* Modified lines are marked with comments
*/
class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T], val offset: Int)
extends Partition with Serializable {
override val index: Int = idx
}
/**
* Represents an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding
* window over them. The ordering is first based on the partition index and then the ordering of
* items within each partition. This is similar to sliding in Scala collections, except that it
* becomes an empty RDD if the window size is greater than the total number of items. It needs to
* trigger a Spark job if the parent RDD has more than one partition. To make this operation
* efficient, the number of items per partition should be larger than the window size and the
* window size should be small, e.g., 2.
*
* @param parent the parent RDD
* @param windowSize the window size, must be greater than 1
* @param step step size for windows
*
* @see [[org.apache.spark.mllib.rdd.RDDFunctions.sliding(Int, Int)*]]
* @see [[scala.collection.IterableLike.sliding(Int, Int)*]]
*/
class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int, val step: Int)
extends RDD[Array[T]](parent) {
require(windowSize > 0 && step > 0 && !(windowSize == 1 && step == 1),
"Window size and step must be greater than 0, " +
s"and they cannot be both 1, but got windowSize = $windowSize and step = $step.")
override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
val part = split.asInstanceOf[SlidingRDDPartition[T]]
(firstParent[T].iterator(part.prev, context) ++ part.tail)
.drop(part.offset)
.sliding(windowSize, step)
.withPartial(true) // modified from false -> true
.map(_.toArray)
}
override def getPreferredLocations(split: Partition): Seq[String] =
firstParent[T].preferredLocations(split.asInstanceOf[SlidingRDDPartition[T]].prev)
override def getPartitions: Array[Partition] = {
val parentPartitions = parent.partitions
val n = parentPartitions.length
if (n == 0) {
Array.empty
} else if (n == 1) {
Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty, 0))
} else {
val w1 = windowSize - 1
// Get partition sizes and first w1 elements.
val (sizes, heads) = parent.mapPartitions { iter =>
val w1Array = iter.take(w1).toArray
Iterator.single((w1Array.length + iter.length, w1Array))
}.collect().unzip
val partitions = mutable.ArrayBuffer.empty[SlidingRDDPartition[T]]
var i = 0
var cumSize = 0
var partitionIndex = 0
while (i < n) {
val mod = cumSize % step
val offset = if (mod == 0) 0 else step - mod
val size = sizes(i)
if (offset < size) {
val tail = mutable.ListBuffer.empty[T]
// Keep appending to the current tail until it has w1 elements.
var j = i + 1
while (j < n && tail.length < w1) {
tail ++= heads(j).take(w1 - tail.length)
j += 1
}
// if (sizes(i) + tail.length >= offset + windowSize) { // modified: removed
partitions +=
new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail, offset)
partitionIndex += 1
// } // modified: removed
}
cumSize += size
i += 1
}
partitions.toArray
}
}
// TODO: Override methods such as aggregate, which only requires one Spark job.
}
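// Example (illustrative): sliding over the elements 1, 2, 3, 4, 5 with windowSize = 2 and
// step = 2 yields the blocks [1,2], [3,4], [5]; the trailing partial window is kept because
// compute() calls withPartial(true).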
package com.gmei
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.mllib.linalg.Vector
/**
* interface defining similarity measurement between 2 vectors
*/
trait VectorDistance extends Serializable {
def apply(vecA: Vector, vecB: Vector): Double
}
/**
* implementation of [[VectorDistance]] that computes cosine similarity
* between two vectors
*/
object Cosine extends VectorDistance {
def apply(vecA: Vector, vecB: Vector): Double = {
val v1 = vecA.toArray.map(_.toFloat)
val v2 = vecB.toArray.map(_.toFloat)
apply(v1, v2)
}
def apply(vecA: Array[Float], vecB: Array[Float]): Double = {
val n = vecA.length
val norm1 = blas.snrm2(n, vecA, 1)
val norm2 = blas.snrm2(n, vecB, 1)
if (norm1 == 0 || norm2 == 0) return 0.0
blas.sdot(n, vecA, 1, vecB, 1) / norm1 / norm2
}
}
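// Worked example (illustrative): Cosine(Vectors.dense(1.0, 0.0), Vectors.dense(1.0, 1.0))
// = 1 / (1 * sqrt(2)) ≈ 0.707, i.e. the cosine of the 45-degree angle between the vectors.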
package com.gmei
import org.apache.spark.SparkContext
import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}
import org.apache.spark.rdd.RDD
object Word2vec extends Serializable {
var context: SparkContext = null
var word2vec = new Word2Vec()
var model: Word2VecModel = null
def setup(context: SparkContext, param: Main.Params): this.type = {
this.context = context
/**
* model = sg (skip-gram)
* update = hs (hierarchical softmax)
*/
word2vec.setLearningRate(param.lr)
.setNumIterations(param.iter)
.setNumPartitions(param.numPartition)
.setMinCount(0)
.setVectorSize(param.dim)
// set the private context-window field via reflection
val word2vecWindowField = word2vec.getClass.getDeclaredField("org$apache$spark$mllib$feature$Word2Vec$$window")
word2vecWindowField.setAccessible(true)
word2vecWindowField.setInt(word2vec, param.window)
this
}
def read(path: String): RDD[Iterable[String]] = {
context.textFile(path).repartition(200).map(_.split("\\s").toSeq)
}
def fit(input: RDD[Iterable[String]]): this.type = {
model = word2vec.fit(input)
this
}
def save(outputPath: String): this.type = {
model.save(context, s"$outputPath.bin")
this
}
def load(path: String): this.type = {
model = Word2VecModel.load(context, path)
this
}
def getVectors = this.model.getVectors
}
package com.gmei.graph
import scala.collection.mutable.ArrayBuffer
object GraphOps {
def setupAlias(nodeWeights: Array[(Long, Double)]): (Array[Int], Array[Double]) = {
val K = nodeWeights.length
val J = Array.fill(K)(0)
val q = Array.fill(K)(0.0)
val smaller = new ArrayBuffer[Int]()
val larger = new ArrayBuffer[Int]()
val sum = nodeWeights.map(_._2).sum
nodeWeights.zipWithIndex.foreach { case ((nodeId, weight), i) =>
q(i) = K * weight / sum
if (q(i) < 1.0) {
smaller.append(i)
} else {
larger.append(i)
}
}
while (smaller.nonEmpty && larger.nonEmpty) {
val small = smaller.remove(smaller.length - 1)
val large = larger.remove(larger.length - 1)
J(small) = large
q(large) = q(large) + q(small) - 1.0
if (q(large) < 1.0) smaller.append(large)
else larger.append(large)
}
(J, q)
}
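// setupEdgeAlias (below) builds the alias table for the node2vec second-order walk: the
// unnormalized probability of stepping from the current node to a destination neighbour is
// weight/p when it returns to the previous node (srcId), weight when the destination is also
// a neighbour of srcId, and weight/q otherwise.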
def setupEdgeAlias(p: Double = 1.0, q: Double = 1.0)(srcId: Long, srcNeighbors: Array[(Long, Double)], dstNeighbors: Array[(Long, Double)]): (Array[Int], Array[Double]) = {
val neighbors_ = dstNeighbors.map { case (dstNeighborId, weight) =>
var unnormProb = weight / q
if (srcId == dstNeighborId) unnormProb = weight / p
else if (srcNeighbors.exists(_._1 == dstNeighborId)) unnormProb = weight
(dstNeighborId, unnormProb)
}
setupAlias(neighbors_)
}
def drawAlias(J: Array[Int], q: Array[Double]): Int = {
val K = J.length
val kk = math.floor(math.random * K).toInt
if (math.random < q(kk)) kk
else J(kk)
}
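// Worked example (illustrative): for two neighbours with weights (0.5, 1.5), setupAlias
// returns q = (0.5, 1.0) and J = (1, 0); drawAlias then yields index 0 with probability
// 0.5 * 0.5 = 0.25 and index 1 with probability 0.75, matching the normalized weights.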
lazy val createUndirectedEdge = (srcId: Long, dstId: Long, weight: Double) => {
Array(
(srcId, Array((dstId, weight))),
(dstId, Array((srcId, weight)))
)
}
lazy val createDirectedEdge = (srcId: Long, dstId: Long, weight: Double) => {
Array(
(srcId, Array((dstId, weight)))
)
}
}
package com.gmei
import java.io.Serializable
package object graph {
case class NodeAttr(var neighbors: Array[(Long, Double)] = Array.empty[(Long, Double)],
var path: Array[Long] = Array.empty[Long]) extends Serializable
case class EdgeAttr(var dstNeighbors: Array[Long] = Array.empty[Long],
var J: Array[Int] = Array.empty[Int],
var q: Array[Double] = Array.empty[Double]) extends Serializable
}
package com.gmei.lib
import scala.reflect.runtime.universe._
/**
* Abstract class for parameter case classes.
* This overrides the [[toString]] method to print all case class fields by name and value.
* @tparam T Concrete parameter class.
*/
abstract class AbstractParams[T: TypeTag] {
private def tag: TypeTag[T] = typeTag[T]
/**
* Finds all case class fields in concrete class instance, and outputs them in JSON-style format:
* {
* [field name]:\t[field value]\n
* [field name]:\t[field value]\n
* ...
* }
*/
override def toString: String = {
val tpe = tag.tpe
val allAccessors = tpe.decls.collect {
case m: MethodSymbol if m.isCaseAccessor => m
}
val mirror = runtimeMirror(getClass.getClassLoader)
val instanceMirror = mirror.reflect(this)
allAccessors.map { f =>
val paramName = f.name.toString
val fieldMirror = instanceMirror.reflectField(f)
val paramValue = fieldMirror.get
s" $paramName:\t$paramValue"
}.mkString("{\n", ",\n", "\n}")
}
}
package com.gmei
import java.util.Random
import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, IndexedRow, IndexedRowMatrix}
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector}
import org.apache.spark.rdd.RDD
package object lsh {
/**
* An id with its hash encoding.
*
*/
final case class SparseSignature(index: Long, bitSet: BitSet) extends Ordered[SparseSignature] {
override def compare(that: SparseSignature): Int = bitSetComparator(this.bitSet, that.bitSet)
}
/**
* An id with its hash encoding and original vector.
*
*/
final case class Signature(index: Long, vector: Vector, bitSet: BitSet) extends Ordered[Signature] {
override def compare(that: Signature): Int = bitSetComparator(this.bitSet, that.bitSet)
}
/**
* Compares two bit sets according to the first different bit
*
*/
def bitSetComparator(a: BitSet, b: BitSet): Int = {
val xor = a ^ b
val firstDifference = xor.nextSetBit(0)
if (firstDifference >= 0) {
if (a.get(firstDifference)) // if the difference is set to 1 on a
1
else
-1
} else {
0
}
}
/**
* Returns a string representation of a BitSet
*/
def bitSetToString(bs: BitSet): String = bs.iterator.mkString(",")
/**
* Returns a local k by d matrix with random gaussian entries mean=0.0 and
* std=1.0
*
* This is a k by d matrix as it is multiplied by the input matrix
*
*/
def localRandomMatrix(d: Int, numFeatures: Int): Matrix = {
val randomGenerator = new Random()
val values = Array.fill(numFeatures * d)(randomGenerator.nextGaussian())
Matrices.dense(numFeatures, d, values)
}
/**
* Converts a given input matrix to a bit set representation using random hyperplanes
*
*/
def matrixToBitSet(inputMatrix: IndexedRowMatrix, localRandomMatrix: Matrix): RDD[Signature] = {
val bitSets = inputMatrix.multiply(localRandomMatrix).rows.map {
indexedRow =>
(indexedRow.index, vectorToBitSet(indexedRow.vector))
}
val originalVectors = inputMatrix.rows.map { row => (row.index, row.vector) }
bitSets.join(originalVectors).map {
case (id, (bitSet, vector)) =>
Signature(id, vector, bitSet)
}
}
/**
* Converts a given input matrix to a bit set representation using random hyperplanes
*
*/
def matrixToBitSetSparse(inputMatrix: IndexedRowMatrix, localRandomMatrix: Matrix): RDD[SparseSignature] = {
inputMatrix.multiply(localRandomMatrix).rows.map {
indexedRow: IndexedRow =>
val bitSet = vectorToBitSet(indexedRow.vector)
SparseSignature(indexedRow.index, bitSet)
}
}
/**
* Converts a vector to a bit set by replacing all values of x with sign(x)
*
*/
def vectorToBitSet(vector: Vector): BitSet = {
val bitSet = new BitSet(vector.size)
vector.toArray.zipWithIndex.map {
case ((value: Double, index: Int)) =>
if (math.signum(value) > 0)
bitSet.set(index)
}
bitSet
}
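// Example (illustrative): the vector (0.3, -1.2, 0.0, 2.0) maps to a BitSet with bits
// {0, 3} set; only strictly positive components produce a set bit.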
/**
* Approximates the cosine distance of two bit sets using their hamming
* distance
*
*/
def hammingToCosine(hammingDistance: Int, d: Double): Double = {
val pr = 1.0 - (hammingDistance / d)
math.cos((1.0 - pr) * math.Pi)
}
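// Worked example (illustrative): with d = 20 hyperplanes and a hamming distance of 2,
// pr = 1 - 2/20 = 0.9, so the approximated cosine is cos(0.1 * pi) ≈ 0.951.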
/**
* Returns the hamming distance between two bit vectors
*
*/
def hamming(vec1: BitSet, vec2: BitSet): Int = {
(vec1 ^ vec2).cardinality()
}
/**
* Compares two bit sets for their equality
*
*/
def bitSetIsEqual(vec1: BitSet, vec2: BitSet): Boolean = {
hamming(vec1, vec2) == 0
}
/**
* Take distinct matrix entry values based on the indices only.
* The actual values are discarded.
*
*/
def distinct(matrix: RDD[MatrixEntry]): RDD[MatrixEntry] = {
matrix.keyBy(m => (m.i, m.j)).reduceByKey((x, y) => x).values
}
}