Word2vec.scala 1.3 KB
package com.gmei

import org.apache.spark.SparkContext
import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}
import org.apache.spark.rdd.RDD

/**
  * Thin stateful wrapper around Spark MLlib's [[org.apache.spark.mllib.feature.Word2Vec]].
  *
  * Intended call order: `setup` -> `read` -> `fit` -> `save` / `getVectors`,
  * or `setup` -> `load` -> `getVectors`.
  * All state lives in the mutable fields below: `setup` must run before any method that
  * uses `context`, and `fit` or `load` must run before any method that uses `model`.
  */
object Word2vec extends Serializable {
  /** Spark context captured by `setup`; null until then. */
  var context: SparkContext = null
  /** Estimator configured by `setup` and trained by `fit`. */
  var word2vec = new Word2Vec()
  /** Model produced by `fit` or `load`; null until one of them runs. */
  var model: Word2VecModel = null

  /**
    * Stores the Spark context and configures the shared estimator from `param`.
    *
    * The model type is skip-gram (sg) with hierarchical-softmax (hs) updates; MLlib's
    * Word2Vec offers no switch for either, so nothing is set for them here.
    *
    * @param context Spark context later used by `read`, `save` and `load`
    * @param param   hyper-parameters read here: `lr`, `iter`, `numPartition`, `dim`, `window`
    * @return this object, for chaining
    */
  def setup(context: SparkContext, param: Main.Params): this.type = {
    this.context = context
    word2vec.setLearningRate(param.lr)
      .setNumIterations(param.iter)
      .setNumPartitions(param.numPartition)
      // minCount 0: every token, however rare, is kept in the vocabulary
      .setMinCount(0)
      .setVectorSize(param.dim)

    // The window size is written via reflection into the compiler-mangled private field,
    // presumably because the targeted Spark version has no public setter for it
    // (newer MLlib releases expose `setWindowSize`).
    // NOTE(review): throws NoSuchFieldException if a Spark upgrade renames this field.
    val windowField =
      word2vec.getClass.getDeclaredField("org$apache$spark$mllib$feature$Word2Vec$$window")
    windowField.setAccessible(true)
    windowField.setInt(word2vec, param.window)

    this
  }

  /**
    * Reads a corpus with one sentence per line and whitespace-separated tokens.
    *
    * Lines are split on runs of whitespace and empty tokens are discarded. Splitting on a
    * single whitespace character would turn repeated spaces, mixed tabs/spaces and leading
    * whitespace into "" tokens, which `setMinCount(0)` would admit into the vocabulary.
    * Lines containing no tokens at all are dropped.
    *
    * @param path          any path accepted by `SparkContext.textFile`
    * @param numPartitions partition count the corpus is repartitioned into (default 200)
    * @return one token sequence per non-blank input line
    */
  def read(path: String, numPartitions: Int = 200): RDD[Iterable[String]] = {
    require(context != null, "Word2vec.setup must be called before read")
    context.textFile(path)
      .repartition(numPartitions)
      // explicit type parameter: RDD is invariant, and the trailing filter would otherwise
      // leave the element type inferred as Seq[String]
      .map[Iterable[String]](_.split("\\s+").filter(_.nonEmpty).toSeq)
      .filter(_.nonEmpty)
  }

  /**
    * Trains the configured estimator on `input` and keeps the resulting model.
    *
    * @param input tokenized sentences, e.g. the output of `read`
    * @return this object, for chaining
    */
  def fit(input: RDD[Iterable[String]]): this.type = {
    model = word2vec.fit(input)
    this
  }

  /**
    * Persists the current model under `outputPath` with ".bin" appended.
    *
    * Note the asymmetry with `load`, which uses its path verbatim: a model written by
    * `save("x")` must be read back with `load("x.bin")`.
    * NOTE(review): the ".bin" suffix is only a name; MLlib writes its own save format.
    *
    * @param outputPath base path; ".bin" is appended to it
    * @return this object, for chaining
    */
  def save(outputPath: String): this.type = {
    require(context != null, "Word2vec.setup must be called before save")
    require(model != null, "Word2vec.fit or Word2vec.load must be called before save")
    model.save(context, s"$outputPath.bin")
    this
  }

  /**
    * Loads a previously saved model from exactly `path` (no suffix is added).
    *
    * @param path location written by MLlib's `Word2VecModel.save`
    * @return this object, for chaining
    */
  def load(path: String): this.type = {
    require(context != null, "Word2vec.setup must be called before load")
    model = Word2VecModel.load(context, path)
    this
  }

  /** Returns the word -> vector map of the current model. */
  def getVectors: Map[String, Array[Float]] = {
    require(model != null, "Word2vec.fit or Word2vec.load must be called before getVectors")
    model.getVectors
  }

}