Word2vec.scala 1.3 KB
Newer Older
高雅喆's avatar
高雅喆 committed
1 2 3 4 5 6 7 8 9 10
package com.gmei

import org.apache.spark.SparkContext
import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}
import org.apache.spark.rdd.RDD

/**
  * Singleton wrapper around Spark MLlib's [[Word2Vec]] estimator.
  *
  * Typical call order: `setup` -> `read` -> `fit` -> `save`, or `setup` -> `load` -> `getVectors`.
  * All state lives in the public mutable fields below; they are kept as nullable `var`s
  * because external code may read or assign them directly.
  */
object Word2vec extends Serializable {
  /** Spark context used for reading the corpus and for saving/loading models; set by `setup`. */
  var context: SparkContext = null
  /** The MLlib estimator that `setup` configures and `fit` trains. */
  var word2vec = new Word2Vec()
  /** The trained or loaded model; stays `null` until `fit` or `load` has been called. */
  var model: Word2VecModel = null

  /**
    * Stores the Spark context and applies the hyper-parameters to the estimator.
    *
    * Original author's note on the algorithm: model = skip-gram, update = hierarchical softmax.
    *
    * @param context Spark context used by `read`, `save` and `load`
    * @param param   parsed options; `lr`, `iter`, `numPartition`, `dim` and `window` are read
    * @return this object, to allow call chaining
    */
  def setup(context: SparkContext, param: Main.Params): this.type = {
    this.context = context
    word2vec
      .setLearningRate(param.lr)
      .setNumIterations(param.iter)
      .setNumPartitions(param.numPartition)
      .setMinCount(0) // keep every token, including ones that occur only once
      .setVectorSize(param.dim)
      // Public setter instead of reflecting on the compiler-mangled private `window` field.
      .setWindowSize(param.window)
    this
  }

  /**
    * Reads a text corpus in which every line is one whitespace-separated sentence.
    *
    * Runs of whitespace count as a single separator and empty tokens are dropped, so
    * double spaces or leading blanks no longer add an empty-string "word" to the vocabulary.
    *
    * @param path          location passed to `SparkContext.textFile`
    * @param numPartitions number of partitions the input is shuffled into (default 200)
    * @return one token sequence per input line
    * @throws IllegalStateException if `setup` has not provided a Spark context
    */
  def read(path: String, numPartitions: Int = 200): RDD[Iterable[String]] = {
    val lines = requireContext.textFile(path).repartition(numPartitions)
    val sentences: RDD[Iterable[String]] =
      lines.map(line => line.split("\\s+").filter(_.nonEmpty).toSeq)
    sentences
  }

  /**
    * Trains the estimator on the given sentences and stores the result in `model`.
    *
    * @param input tokenised sentences, e.g. the output of `read`
    * @return this object, to allow call chaining
    */
  def fit(input: RDD[Iterable[String]]): this.type = {
    model = word2vec.fit(input)
    this
  }

  /**
    * Saves the current model to `outputPath` with the suffix ".bin" appended.
    *
    * @param outputPath destination path without the ".bin" suffix
    * @return this object, to allow call chaining
    * @throws IllegalStateException if no model or no Spark context is available
    */
  def save(outputPath: String): this.type = {
    requireModel.save(requireContext, s"$outputPath.bin")
    this
  }

  /**
    * Loads a previously saved model from `path` into `model`.
    *
    * Unlike `save`, no suffix is appended to the path.
    *
    * @param path location of the saved model
    * @return this object, to allow call chaining
    * @throws IllegalStateException if `setup` has not provided a Spark context
    */
  def load(path: String): this.type = {
    model = Word2VecModel.load(requireContext, path)
    this
  }

  /**
    * Returns the word-to-vector mapping of the current model.
    *
    * @throws IllegalStateException if neither `fit` nor `load` has been called
    */
  def getVectors: Map[String, Array[Float]] = requireModel.getVectors

  /** Returns the Spark context or fails with a descriptive error when `setup` was skipped. */
  private def requireContext: SparkContext = {
    if (context == null) {
      throw new IllegalStateException("Word2vec.setup must be called before using the Spark context")
    }
    context
  }

  /** Returns the current model or fails with a descriptive error when none is present. */
  private def requireModel: Word2VecModel = {
    if (model == null) {
      throw new IllegalStateException("No Word2Vec model available: call fit or load first")
    }
    model
  }
}