Commit ca2ed678 authored by 高雅喆

add libraryDependencies cosine-lsh-join-spark_2.10-1.0.6.jar

parent b7043a9d
Copyright 2014 Typesafe, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
## Spark with Hadoop jar
http://mirror.bit.edu.cn/apache/spark/spark-2.2.1/spark-2.2.1-bin-hadoop2.7.tgz
## TiSpark driver jar
http://download.pingcap.org/tispark-1.0-RC1-jar-with-dependencies.jar
## JDK >= 1.8.0
http://www.oracle.com/technetwork/java/javase/downloads/jdk8-downloads-2133151.html
## MySQL driver
http://central.maven.org/maven2/mysql/mysql-connector-java/5.1.44/mysql-connector-java-5.1.44.jar
## Jar placement
Copy the MySQL and TiSpark jars into:
/usr/local/Cellar/apache-spark/2.1.2/jars
## Spark configuration
Edit the files under /usr/local/Cellar/apache-spark/2.1.2/conf.
spark-env.sh:
export SPARK_EXECUTOR_MEMORY=3g
export SPARK_WORKER_MEMORY=3g
export SPARK_WORKER_CORES=2
spark-defaults.conf:
spark.tispark.pd.addresses 192.168.15.11:2379
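With the jars in place and `spark.tispark.pd.addresses` pointing at the PD endpoint, TiDB tables can be queried directly through Spark SQL. A minimal sketch, assuming the TiSpark 1.0 `TiContext` API; the database and table names are placeholders:

```scala
import org.apache.spark.sql.{SparkSession, TiContext}

object TiSparkSmoke {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("tispark-smoke")
      .config("spark.tispark.pd.addresses", "192.168.15.11:2379")
      .getOrCreate()

    // Map every table of a TiDB database into Spark's catalog.
    val ti = new TiContext(spark)
    ti.tidbMapDatabase("jerry_prod") // placeholder database name

    // Once mapped, TiDB tables are queryable like any Spark table.
    spark.sql("SELECT count(1) FROM data_feed_click").show() // placeholder table
  }
}
```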
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.6")
#Activator-generated Properties
#Thu Mar 08 19:32:05 CST 2018
template.uuid=e17acfbb-1ff5-41f5-b8cf-2c40be6a8340
sbt.version=1.0.4
sbt.version = 0.13.15
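// The commit title mentions adding the cosine-lsh-join-spark dependency. In
// build.sbt that would look roughly like the line below; the group id is
// inferred from the com.soundcloud.lsh imports further down, so treat the
// exact coordinates as an assumption.
libraryDependencies += "com.soundcloud" % "cosine-lsh-join-spark_2.10" % "1.0.6"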
#### source table naming rules
src_$dbname_$tablename, e.g.
src_mimas_prod_api_preoperationimage
#### etl table naming rules
mid_$tablename, e.g.
mid_diary_tags
#### theme table naming rules
biz_$tablename, e.g.
biz_user_index_diary_list_ctr
package com.gmei
import java.util.Arrays
import scala.util.hashing.MurmurHash3
/**
* A simple, fixed-size bit set implementation. This implementation is fast because it avoids
* safety/bound checking.
*/
class BitSet(val numBits: Int) extends Serializable {
private val words = new Array[Long](bit2words(numBits))
private val numWords = words.length
/**
* Compute the capacity (number of bits) that can be represented
* by this bitset.
*/
def capacity: Int = numWords * 64
/**
* Clear all set bits.
*/
def clear(): Unit = Arrays.fill(words, 0)
/**
* Set all the bits up to a given index
*/
def setUntil(bitIndex: Int): Unit = {
val wordIndex = bitIndex >> 6 // divide by 64
Arrays.fill(words, 0, wordIndex, -1)
if(wordIndex < words.length) {
// Set the remaining bits (note that the mask could still be zero)
val mask = ~(-1L << (bitIndex & 0x3f))
words(wordIndex) |= mask
}
}
/**
* Clear all the bits up to a given index
*/
def clearUntil(bitIndex: Int): Unit = {
val wordIndex = bitIndex >> 6 // divide by 64
Arrays.fill(words, 0, wordIndex, 0)
if(wordIndex < words.length) {
// Clear the remaining bits
val mask = -1L << (bitIndex & 0x3f)
words(wordIndex) &= mask
}
}
/**
* Compute the bit-wise AND of the two sets returning the
* result.
*/
def &(other: BitSet): BitSet = {
val newBS = new BitSet(math.max(capacity, other.capacity))
val smaller = math.min(numWords, other.numWords)
assert(newBS.numWords >= numWords)
assert(newBS.numWords >= other.numWords)
var ind = 0
while( ind < smaller ) {
newBS.words(ind) = words(ind) & other.words(ind)
ind += 1
}
newBS
}
/**
* Compute the bit-wise OR of the two sets returning the
* result.
*/
def |(other: BitSet): BitSet = {
val newBS = new BitSet(math.max(capacity, other.capacity))
assert(newBS.numWords >= numWords)
assert(newBS.numWords >= other.numWords)
val smaller = math.min(numWords, other.numWords)
var ind = 0
while( ind < smaller ) {
newBS.words(ind) = words(ind) | other.words(ind)
ind += 1
}
while( ind < numWords ) {
newBS.words(ind) = words(ind)
ind += 1
}
while( ind < other.numWords ) {
newBS.words(ind) = other.words(ind)
ind += 1
}
newBS
}
/**
* Compute the symmetric difference by performing bit-wise XOR of the two sets returning the
* result.
*/
def ^(other: BitSet): BitSet = {
val newBS = new BitSet(math.max(capacity, other.capacity))
val smaller = math.min(numWords, other.numWords)
var ind = 0
while (ind < smaller) {
newBS.words(ind) = words(ind) ^ other.words(ind)
ind += 1
}
if (ind < numWords) {
Array.copy( words, ind, newBS.words, ind, numWords - ind )
}
if (ind < other.numWords) {
Array.copy( other.words, ind, newBS.words, ind, other.numWords - ind )
}
newBS
}
/**
* Compute the difference of the two sets by performing bit-wise AND-NOT returning the
* result.
*/
def andNot(other: BitSet): BitSet = {
val newBS = new BitSet(capacity)
val smaller = math.min(numWords, other.numWords)
var ind = 0
while (ind < smaller) {
newBS.words(ind) = words(ind) & ~other.words(ind)
ind += 1
}
if (ind < numWords) {
Array.copy( words, ind, newBS.words, ind, numWords - ind )
}
newBS
}
/**
* Sets the bit at the specified index to true.
* @param index the bit index
*/
def set(index: Int) {
val bitmask = 1L << (index & 0x3f) // mod 64 and shift
words(index >> 6) |= bitmask // div by 64 and mask
}
def unset(index: Int) {
val bitmask = 1L << (index & 0x3f) // mod 64 and shift
words(index >> 6) &= ~bitmask // div by 64 and mask
}
def flip(index: Int) {
val bitmask = 1L << (index & 0x3f)
words(index >> 6) ^= bitmask
}
/**
* Return the value of the bit with the specified index. The value is true if the bit with
* the index is currently set in this BitSet; otherwise, the result is false.
*
* @param index the bit index
* @return the value of the bit with the specified index
*/
def get(index: Int): Boolean = {
val bitmask = 1L << (index & 0x3f) // mod 64 and shift
(words(index >> 6) & bitmask) != 0 // div by 64 and mask
}
/**
* Get an iterator over the set bits.
*/
def iterator: Iterator[Int] = new Iterator[Int] {
var ind = nextSetBit(0)
override def hasNext: Boolean = ind >= 0
override def next(): Int = {
val tmp = ind
ind = nextSetBit(ind + 1)
tmp
}
}
/** Return the number of bits set to true in this BitSet. */
def cardinality(): Int = {
var sum = 0
var i = 0
while (i < numWords) {
sum += java.lang.Long.bitCount(words(i))
i += 1
}
sum
}
/**
* Returns the index of the first bit that is set to true that occurs on or after the
* specified starting index. If no such bit exists then -1 is returned.
*
* To iterate over the true bits in a BitSet, use the following loop:
*
* for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i+1)) {
* // operate on index i here
* }
*
* @param fromIndex the index to start checking from (inclusive)
* @return the index of the next set bit, or -1 if there is no such bit
*/
def nextSetBit(fromIndex: Int): Int = {
var wordIndex = fromIndex >> 6
if (wordIndex >= numWords) {
return -1
}
// Try to find the next set bit in the current word
val subIndex = fromIndex & 0x3f
var word = words(wordIndex) >> subIndex
if (word != 0) {
return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word)
}
// Find the next set bit in the rest of the words
wordIndex += 1
while (wordIndex < numWords) {
word = words(wordIndex)
if (word != 0) {
return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word)
}
wordIndex += 1
}
-1
}
/** Return the number of longs it would take to hold numBits. */
private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1
/** Override hashCode method to allow BitSet be used as an RDD key */
override def hashCode() = MurmurHash3.arrayHash(words)
}
object BitSet {
def apply(bitSet: BitSet) = {
val bitSetCopy = new BitSet(bitSet.numBits)
bitSet.iterator.foreach(ix => bitSetCopy.set(ix))
bitSetCopy
}
}
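// A minimal usage sketch of the BitSet above; BitSetExample is a hypothetical
// driver, not part of the original sources.
object BitSetExample {
  def main(args: Array[String]): Unit = {
    val bs = new BitSet(128)
    bs.set(3)
    bs.set(64)
    bs.setUntil(2) // sets bits 0 and 1
    assert(bs.get(64) && bs.cardinality() == 4)

    // Iterate over the set bits with nextSetBit, as the scaladoc recommends.
    var i = bs.nextSetBit(0)
    while (i >= 0) {
      println(i) // prints 0, 1, 3, 64
      i = bs.nextSetBit(i + 1)
    }

    // Bit-wise operators return a new set sized to the larger operand.
    val other = new BitSet(64)
    other.set(3)
    assert((bs & other).cardinality() == 1)
  }
}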
@@ -30,14 +30,14 @@ object GmeiConfig extends Serializable {
this
}
// Here param1 would be null, so this would throw a NullPointerException:
// val env = param1.env match {
// case "prod" => ENV.PROD
// case "dev" => ENV.DEV
// case "pre" => ENV.PRE
// case _ => ENV.DEV
// }
// val config = initConfig(env)
def initConfig(env: ENV) = {
lazy val c = ConfigFactory.load()
@@ -77,7 +77,7 @@ object GmeiConfig extends Serializable {
prop.put("driver", "com.mysql.jdbc.Driver")
prop.put("useSSL", "false")
prop.put("isolationLevel", "NONE")
prop.put("isTruncate", "true")
prop.put("truncate", "true")
// save to mysql/tidb
df.repartition(128).write.mode(saveModel)
.option(JDBCOptions.JDBC_BATCH_INSERT_SIZE, 300)
......
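// The isTruncate -> truncate rename above matters: Spark's JDBC writer only
// recognizes the "truncate" option (truncate the table instead of dropping it
// on SaveMode.Overwrite), and unknown keys are silently ignored. A sketch of
// the surrounding write path, with placeholder URL and table arguments:
import java.util.Properties
import org.apache.spark.sql.{DataFrame, SaveMode}

object JdbcWriteSketch {
  def writeToJDBC(df: DataFrame, jdbcUrl: String, table: String): Unit = {
    val prop = new Properties()
    prop.put("driver", "com.mysql.jdbc.Driver")
    prop.put("useSSL", "false")
    prop.put("isolationLevel", "NONE")
    prop.put("truncate", "true")

    // Save to MySQL/TiDB in JDBC batches of 300 rows.
    df.repartition(128).write
      .mode(SaveMode.Overwrite)
      .option("batchsize", 300)
      .jdbc(jdbcUrl, table, prop)
  }
}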
package com.gmei
import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, IndexedRowMatrix}
trait Joiner {
/**
* Find the k nearest neighbors from a data set for every other object in the
* same data set. Implementations may be either exact or approximate.
*
* @param matrix a row oriented matrix. Each row in the matrix represents
* an item in the data set. Items are identified by their
* matrix index.
* @return a similarity matrix with MatrixEntry(itemA, itemB, similarity).
*
*/
def join(matrix: IndexedRowMatrix): CoordinateMatrix
}
trait QueryJoiner {
/**
* Find the k nearest neighbours in catalogMatrix for each entry in queryMatrix.
* Implementations may be either exact or approximate.
*
* @param queryMatrix a row oriented matrix. Each row in the matrix represents
* an item in the data set. Items are identified by their
* matrix index.
* @param catalogMatrix a row oriented matrix. Each row in the matrix represents
* an item in the data set. Items are identified by their
* matrix index.
* @return a similarity matrix with MatrixEntry(queryA, catalogB, similarity).
*/
def join(queryMatrix: IndexedRowMatrix, catalogMatrix: IndexedRowMatrix): CoordinateMatrix
}
package com.gmei
import org.apache.spark.mllib.linalg.distributed._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import com.gmei.lsh.{localRandomMatrix,matrixToBitSet,distinct,Signature}
/**
* Lsh implementation as described in 'Randomized Algorithms and NLP: Using
* Locality Sensitive Hash Function for High Speed Noun Clustering' by
* Ravichandran et al. See original publication for a detailed description of
* the parameters.
*
* @see http://dl.acm.org/citation.cfm?id=1219917
* @param minCosineSimilarity minimum similarity two items need to have
* otherwise they are discarded from the result set
* @param dimensions number of random vectors (hyperplanes) to generate bit
* vectors of length d
* @param numNeighbours beam factor e.g. how many neighbours are considered
* in the sliding window
* @param numPermutations number of times bitsets are permuted
*
*/
class Lsh(minCosineSimilarity: Double,
dimensions: Int,
numNeighbours: Int,
numPermutations: Int,
partitions: Int = 200,
storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK)
extends Joiner with Serializable {
override def join(inputMatrix: IndexedRowMatrix): CoordinateMatrix = {
val numFeatures = inputMatrix.numCols().toInt
val randomMatrix = localRandomMatrix(dimensions, numFeatures)
val signatures = matrixToBitSet(inputMatrix, randomMatrix).
repartition(partitions).
persist(storageLevel)
val neighbours = 0 until numPermutations map {
x =>
val permutation = generatePermutation(dimensions)
val permuted = permuteBitSet(signatures, permutation, dimensions)
val orderedSignatures = orderByBitSet(permuted)
val slidingWindow = createSlidingWindow(orderedSignatures, numNeighbours)
findNeighbours(slidingWindow, minCosineSimilarity)
}
val mergedNeighbours = neighbours.reduce(_ ++ _)
new CoordinateMatrix(distinct(mergedNeighbours))
}
/**
* Permutes a signatures by a given permutation
*
*/
def permuteBitSet(signatures: RDD[Signature], permutation: Iterable[Int], d: Int): RDD[Signature] = {
signatures.map {
signature: Signature =>
val permutedBitSet = permuteBitSet(signature.bitSet, permutation, d)
signature.copy(bitSet = permutedBitSet)
}
}
/**
* Permutes a bit set representation of a vector by a given permutation
*/
def permuteBitSet(bitSet: BitSet, permutation: Iterable[Int], d: Int): BitSet = {
val permutationWithIndex = permutation.zipWithIndex
val newBitSet = new BitSet(d)
permutationWithIndex.foreach {
case ((newIndex: Int, oldIndex: Int)) =>
val oldBit = bitSet.get(oldIndex)
if (oldBit)
newBitSet.set(newIndex)
else
newBitSet.unset(newIndex)
}
newBitSet
}
/**
* Generates a random permutation of size n
*/
def generatePermutation(size: Int): Iterable[Int] = {
val indices = (0 until size).toArray
util.Random.shuffle(indices)
}
/**
* Orderes an RDD of signatures by their bit set representation
*/
def orderByBitSet(signatures: RDD[Signature]): RDD[Signature] = {
signatures.sortBy(identity)
}
/**
* Creates a sliding window
*
*/
def createSlidingWindow(signatures: RDD[Signature], b: Int): RDD[Array[Signature]] = {
new SlidingRDD[Signature](signatures, b, b)
}
def findNeighbours(signatures: RDD[Array[Signature]], minCosineSimilarity: Double): RDD[MatrixEntry] = {
signatures.flatMap { signature: Array[Signature] =>
neighbours(signature, minCosineSimilarity)
}
}
/**
* Generate all pairs and emit if cosine of pair > minCosineSimilarity
*
*/
def neighbours(signatures: Array[Signature], minCosineSimilarity: Double): Iterator[MatrixEntry] = {
signatures.
sortBy(_.index). // sort in order to create an upper triangular matrix
combinations(2).
map {
case Array(first, second) =>
val cosine = Cosine(first.vector, second.vector)
MatrixEntry(first.index, second.index, cosine)
}.
filter(_.value >= minCosineSimilarity)
}
}
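// An end-to-end sketch of running the joiner above on a toy matrix; the
// parameter values are illustrative and LshExample is a hypothetical driver.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}

object LshExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("lsh").setMaster("local[2]"))
    val rows = sc.parallelize(Seq(
      IndexedRow(0L, Vectors.dense(1.0, 0.0, 1.0)),
      IndexedRow(1L, Vectors.dense(1.0, 0.1, 1.0)),
      IndexedRow(2L, Vectors.dense(0.0, 1.0, 0.0))
    ))
    val lsh = new Lsh(
      minCosineSimilarity = 0.5,
      dimensions = 20,
      numNeighbours = 2,
      numPermutations = 4,
      partitions = 2
    )
    // Entries are MatrixEntry(itemA, itemB, cosine) for pairs above the threshold.
    lsh.join(new IndexedRowMatrix(rows)).entries.collect().foreach(println)
    sc.stop()
  }
}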
@@ -10,6 +10,8 @@ import org.apache.spark.sql.{SaveMode,TiContext}
import org.apache.log4j.{Level, Logger}
import scopt.OptionParser
import com.gmei.lib.AbstractParams
import com.soundcloud.lsh.Lsh
......
@@ -38,7 +38,7 @@ object Node2vec extends Serializable {
}
val inputTriplets: RDD[(Long, Long, Double)] = config.indexed match {
// case true => readIndexedGraph(config.input)
case false => indexingGraph(tidb_input)
}
@@ -149,54 +149,54 @@ object Node2vec extends Serializable {
this
}
// def save(): this.type = {
// this.saveRandomPath()
// .saveModel()
// .saveVectors()
// }
// def saveRandomPath(): this.type = {
// randomWalkPaths
// .map { case (vertexId, pathBuffer) =>
// Try(pathBuffer.mkString("\t")).getOrElse(null)
// }
// .filter(x => x != null && x.replaceAll("\\s", "").length > 0)
// .repartition(200)
// .saveAsTextFile(config.output)
//
// this
// }
//
// def saveModel(): this.type = {
// Word2vec.save(config.output)
//
// this
// }
//
// def saveVectors(): this.type = {
// val node2vector = context.parallelize(Word2vec.getVectors.toList)
// .map { case (nodeId, vector) =>
// (nodeId.toLong, vector.mkString(","))
// }
//
// if (this.node2id != null) {
// val id2Node = this.node2id.map{ case (strNode, index) =>
// (index, strNode)
// }
//
// node2vector.join(id2Node)
// .map { case (nodeId, (vector, name)) => s"$name\t$vector" }
// .repartition(200)
// .saveAsTextFile(s"${config.output}.emb")
// } else {
// node2vector.map { case (nodeId, vector) => s"$nodeId\t$vector" }
// .repartition(200)
// .saveAsTextFile(s"${config.output}.emb")
// }
//
// this
// }
//
def cleanup(): this.type = {
node2id.unpersist(blocking = false)
indexedEdges.unpersist(blocking = false)
@@ -221,29 +221,29 @@ object Node2vec extends Serializable {
this
}
// def readIndexedGraph(tripletPath: String) = {
// val bcWeighted = context.broadcast(config.weighted)
//
// val rawTriplets = context.textFile(tripletPath)
// if (config.nodePath == null) {
// this.node2id = createNode2Id(rawTriplets.map { triplet =>
// val parts = triplet.split("\\s")
// (parts.head, parts(1), -1)
// })
// } else {
// loadNode2Id(config.nodePath)
// }
//
// rawTriplets.map { triplet =>
// val parts = triplet.split("\\s")
// val weight = bcWeighted.value match {
// case true => Try(parts.last.toDouble).getOrElse(1.0)
// case false => 1.0
// }
//
// (parts.head.toLong, parts(1).toLong, weight)
// }
// }
def indexingGraph(tidb_input: DataFrame): RDD[(Long, Long, Double)] = {
......
package com.gmei
import scala.collection.mutable
import scala.reflect.ClassTag
import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
/**
* NOTE: both classes are copied from mllib and slightly modified since these classes are mllib private!
* Modified lines are marked with comments
*/
class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T], val offset: Int)
extends Partition with Serializable {
override val index: Int = idx
}
/**
* Represents an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding
* window over them. The ordering is first based on the partition index and then the ordering of
* items within each partition. This is similar to sliding in Scala collections, except that it
* becomes an empty RDD if the window size is greater than the total number of items. It needs to
* trigger a Spark job if the parent RDD has more than one partitions. To make this operation
* efficient, the number of items per partition should be larger than the window size and the
* window size should be small, e.g., 2.
*
* @param parent the parent RDD
* @param windowSize the window size, must be greater than 1
* @param step step size for windows
*
* @see [[org.apache.spark.mllib.rdd.RDDFunctions.sliding(Int, Int)*]]
* @see [[scala.collection.IterableLike.sliding(Int, Int)*]]
*/
class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int, val step: Int)
extends RDD[Array[T]](parent) {
require(windowSize > 0 && step > 0 && !(windowSize == 1 && step == 1),
"Window size and step must be greater than 0, " +
s"and they cannot be both 1, but got windowSize = $windowSize and step = $step.")
override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
val part = split.asInstanceOf[SlidingRDDPartition[T]]
(firstParent[T].iterator(part.prev, context) ++ part.tail)
.drop(part.offset)
.sliding(windowSize, step)
.withPartial(true) // modified from false -> true
.map(_.toArray)
}
override def getPreferredLocations(split: Partition): Seq[String] =
firstParent[T].preferredLocations(split.asInstanceOf[SlidingRDDPartition[T]].prev)
override def getPartitions: Array[Partition] = {
val parentPartitions = parent.partitions
val n = parentPartitions.length
if (n == 0) {
Array.empty
} else if (n == 1) {
Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty, 0))
} else {
val w1 = windowSize - 1
// Get partition sizes and first w1 elements.
val (sizes, heads) = parent.mapPartitions { iter =>
val w1Array = iter.take(w1).toArray
Iterator.single((w1Array.length + iter.length, w1Array))
}.collect().unzip
val partitions = mutable.ArrayBuffer.empty[SlidingRDDPartition[T]]
var i = 0
var cumSize = 0
var partitionIndex = 0
while (i < n) {
val mod = cumSize % step
val offset = if (mod == 0) 0 else step - mod
val size = sizes(i)
if (offset < size) {
val tail = mutable.ListBuffer.empty[T]
// Keep appending to the current tail until it has w1 elements.
var j = i + 1
while (j < n && tail.length < w1) {
tail ++= heads(j).take(w1 - tail.length)
j += 1
}
// if (sizes(i) + tail.length >= offset + windowSize) { // modified: removed
partitions +=
new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail, offset)
partitionIndex += 1
// } // modified: removed
}
cumSize += size
i += 1
}
partitions.toArray
}
}
// TODO: Override methods such as aggregate, which only requires one Spark job.
}
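// A small driver sketching what SlidingRDD emits; with the withPartial(true)
// modification above, the trailing window shorter than windowSize is kept
// rather than dropped. SlidingExample is a hypothetical name.
import org.apache.spark.{SparkConf, SparkContext}

object SlidingExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sliding").setMaster("local[1]"))
    val data = sc.parallelize(1 to 7, numSlices = 1)
    val windows = new SlidingRDD[Int](data, windowSize = 3, step = 3)
    // Prints 1,2,3 / 4,5,6 / 7 -- the last, partial window survives.
    windows.collect().foreach(w => println(w.mkString(",")))
    sc.stop()
  }
}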
package com.gmei
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.mllib.linalg.Vector
/**
* interface defining similarity measurement between 2 vectors
*/
trait VectorDistance extends Serializable {
def apply(vecA: Vector, vecB: Vector): Double
}
/**
* implementation of [[VectorDistance]] that computes cosine similarity
* between two vectors
*/
object Cosine extends VectorDistance {
def apply(vecA: Vector, vecB: Vector): Double = {
val v1 = vecA.toArray.map(_.toFloat)
val v2 = vecB.toArray.map(_.toFloat)
apply(v1, v2)
}
def apply(vecA: Array[Float], vecB: Array[Float]): Double = {
val n = vecA.length
val norm1 = blas.snrm2(n, vecA, 1)
val norm2 = blas.snrm2(n, vecB, 1)
if (norm1 == 0 || norm2 == 0) return 0.0
blas.sdot(n, vecA, 1, vecB, 1) / norm1 / norm2
}
}
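// A quick hand-check of the cosine measure above; values are chosen so the
// expected results are obvious (inputs pass through Float, so allow for small
// rounding error). CosineCheck is a hypothetical driver.
object CosineCheck {
  def main(args: Array[String]): Unit = {
    import org.apache.spark.mllib.linalg.Vectors
    val a = Vectors.dense(1.0, 0.0)
    val b = Vectors.dense(1.0, 1.0)
    println(Cosine(a, b)) // ~0.7071, i.e. cos(45 degrees)
    println(Cosine(a, a)) // ~1.0
  }
}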
package com.gmei
import java.util.Random
import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, IndexedRow, IndexedRowMatrix}
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector}
import org.apache.spark.rdd.RDD
package object lsh {
/**
* An id with it's hash encoding.
*
*/
final case class SparseSignature(index: Long, bitSet: BitSet) extends Ordered[SparseSignature] {
override def compare(that: SparseSignature): Int = bitSetComparator(this.bitSet, that.bitSet)
}
/**
* An id with it's hash encoding and original vector.
*
*/
final case class Signature(index: Long, vector: Vector, bitSet: BitSet) extends Ordered[Signature] {
override def compare(that: Signature): Int = bitSetComparator(this.bitSet, that.bitSet)
}
/**
* Compares two bit sets according to the first different bit
*
*/
def bitSetComparator(a: BitSet, b: BitSet): Int = {
val xor = a ^ b
val firstDifference = xor.nextSetBit(0)
if (firstDifference >= 0) {
if (a.get(firstDifference)) // if the difference is set to 1 on a
1
else
-1
} else {
0
}
}
/**
* Returns a string representation of a BitSet
*/
def bitSetToString(bs: BitSet): String = bs.iterator.mkString(",")
/**
* Returns a local k by d matrix with random gaussian entries mean=0.0 and
* std=1.0
*
* This is a k by d matrix as it is multiplied by the input matrix
*
*/
def localRandomMatrix(d: Int, numFeatures: Int): Matrix = {
val randomGenerator = new Random()
val values = Array.fill(numFeatures * d)(randomGenerator.nextGaussian())
Matrices.dense(numFeatures, d, values)
}
/**
* Converts a given input matrix to a bit set representation using random hyperplanes
*
*/
def matrixToBitSet(inputMatrix: IndexedRowMatrix, localRandomMatrix: Matrix): RDD[Signature] = {
val bitSets = inputMatrix.multiply(localRandomMatrix).rows.map {
indexedRow =>
(indexedRow.index, vectorToBitSet(indexedRow.vector))
}
val originalVectors = inputMatrix.rows.map { row => (row.index, row.vector) }
bitSets.join(originalVectors).map {
case (id, (bitSet, vector)) =>
Signature(id, vector, bitSet)
}
}
/**
* Converts a given input matrix to a bit set representation using random hyperplanes
*
*/
def matrixToBitSetSparse(inputMatrix: IndexedRowMatrix, localRandomMatrix: Matrix): RDD[SparseSignature] = {
inputMatrix.multiply(localRandomMatrix).rows.map {
indexedRow: IndexedRow =>
val bitSet = vectorToBitSet(indexedRow.vector)
SparseSignature(indexedRow.index, bitSet)
}
}
/**
* Converts a vector to a bit set by replacing all values of x with sign(x)
*
*/
def vectorToBitSet(vector: Vector): BitSet = {
val bitSet = new BitSet(vector.size)
vector.toArray.zipWithIndex.map {
case ((value: Double, index: Int)) =>
if (math.signum(value) > 0)
bitSet.set(index)
}
bitSet
}
/**
* Approximates the cosine distance of two bit sets using their hamming
* distance
*
*/
def hammingToCosine(hammingDistance: Int, d: Double): Double = {
val pr = 1.0 - (hammingDistance / d)
math.cos((1.0 - pr) * math.Pi)
}
/**
* Returns the hamming distance between two bit vectors
*
*/
def hamming(vec1: BitSet, vec2: BitSet): Int = {
(vec1 ^ vec2).cardinality()
}
/**
* Compares two bit sets for their equality
*
*/
def bitSetIsEqual(vec1: BitSet, vec2: BitSet): Boolean = {
hamming(vec1, vec2) == 0
}
/**
* Take distinct matrix entry values based on the indices only.
* The actual values are discarded.
*
*/
def distinct(matrix: RDD[MatrixEntry]): RDD[MatrixEntry] = {
matrix.keyBy(m => (m.i, m.j)).reduceByKey((x, y) => x).values
}
}
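// A sketch tying the helpers above together: hash two vectors against the same
// random hyperplanes, then estimate their cosine similarity from the hamming
// distance of the signatures. LshHelpersCheck is a hypothetical driver; more
// hyperplanes (larger d) give a tighter estimate.
import org.apache.spark.mllib.linalg.DenseVector
import com.gmei.lsh._

object LshHelpersCheck {
  def main(args: Array[String]): Unit = {
    val d = 256 // number of hyperplanes
    val planes = localRandomMatrix(d, numFeatures = 3) // numFeatures x d

    // Project a local vector onto the hyperplanes and keep only the signs.
    def signature(v: DenseVector) = vectorToBitSet(planes.transpose.multiply(v))

    val s1 = signature(new DenseVector(Array(1.0, 0.0, 1.0)))
    val s2 = signature(new DenseVector(Array(1.0, 0.1, 1.0)))

    // hammingToCosine inverts the hamming distance into a cosine estimate.
    println(hammingToCosine(hamming(s1, s2), d))
  }
}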