Commit ada16bb5 authored by 赵威

item2vec

parent 0254dd52
......@@ -4,6 +4,7 @@ import pandas as pd
# All paths below are anchored at the current working directory of the process.
base_dir = os.getcwd()
# "_data" sub-directory of the working directory.
DATA_PATH = os.path.join(base_dir, "_data")
# "_model" sub-directory of the working directory.
# NOTE(review): code elsewhere in this change built its model directory as
# "_models" (plural) -- confirm that the singular "_model" is intended, since
# models saved under the old name may not be found here.
MODEL_PATH = os.path.join(base_dir, "_model")
def remove_file(path):
......
import os
import sys
from collections import defaultdict
from datetime import date, timedelta
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pytispark import pytispark as pti
# Put the resolved current directory on the import path so modules located
# there can be imported.
sys.path.append(os.path.realpath("."))
# Working directory of the process; the data directory is derived from it.
base_dir = os.getcwd()
print("base_dir: " + base_dir)
# "_data" sub-directory of the working directory.
data_dir = os.path.join(base_dir, "_data")
def get_ndays_before_with_format(n, format):
    """Return the date `n` days before today as a formatted string.

    Args:
        n: Number of days to go back from today's local date.
        format: `strftime` pattern used to render the resulting date.

    Returns:
        The formatted date string.
    """
    offset = timedelta(days=-n)
    target_day = date.today() + offset
    return target_day.strftime(format)
def get_ndays_before_no_minus(n):
    """Return the date `n` days before today formatted as "%Y%m%d".

    The pattern contains no separators, so the result is a compact string
    with no minus signs between year, month and day.
    """
    compact_pattern = "%Y%m%d"
    return get_ndays_before_with_format(n, compact_pattern)
def get_spark(app_name=""):
    """Create (or fetch) a Hive-enabled SparkSession set up for TiSpark.

    The session is built with the configuration entries listed below, the
    TiSpark SQL extension and a fixed PD address. Afterwards the Spark log
    level is set to "ERROR" and the "jerry_test" database is mapped through
    a TiContext.

    Args:
        app_name: Application name passed to the session builder.

    Returns:
        The SparkSession returned by `getOrCreate()`.
    """
    # (key, value) pairs are applied in this exact order; the values are
    # passed to SparkConf.set unchanged.
    settings = (
        ("spark.sql.crossJoin.enabled", True),
        ("spark.debug.maxToStringFields", "100"),
        ("spark.tispark.plan.allow_index_double_read", False),
        ("spark.tispark.plan.allow_index_read", True),
        ("spark.hive.mapred.supports.subdirectories", True),
        ("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", True),
        ("spark.serializer", "org.apache.spark.serializer.KryoSerializer"),
        ("mapreduce.output.fileoutputformat.compress", False),
        ("mapreduce.map.output.compress", False),
    )
    conf = SparkConf()
    for key, value in settings:
        conf.set(key, value)

    builder = SparkSession.builder.config(conf=conf)
    builder = builder.config("spark.sql.extensions", "org.apache.spark.sql.TiExtensions")
    builder = builder.config("spark.tispark.pd.addresses", "172.16.40.170:2379")
    builder = builder.appName(app_name).enableHiveSupport()
    session = builder.getOrCreate()

    session.sparkContext.setLogLevel("ERROR")
    # sc.addPyFile("/srv/apps/strategy_embedding/utils/date.py")

    tidb_context = pti.TiContext(session)
    tidb_context.tidbMapDatabase("jerry_test")
    return session
from utils.date import get_ndays_before_no_minus, get_ndays_before_with_format
from utils.files import DATA_PATH
from utils.spark import get_spark
def get_tracate_click_data(spark, start, end):
......@@ -164,7 +130,7 @@ if __name__ == "__main__":
res_dict = get_device_click_tractate_ids_dict(click_df)
with open(os.path.join(data_dir, "click_tractate_ids.csv"), "w") as f:
with open(os.path.join(DATA_PATH, "click_tractate_ids.csv"), "w") as f:
for (k, v) in res_dict.items():
if v:
f.write("{}|{}\n".format(k, ",".join([str(x) for x in v])))
......
import multiprocessing
import os
import sys
import time
import traceback
sys.path.append(os.path.realpath("."))
from gensim.models import Word2Vec, word2vec
from gm_rpcd.all import bind
from utils.es import es_scan
from utils.message import send_msg_to_dingtalk
base_dir = os.getcwd()
print("base_dir: " + base_dir)
model_dir = os.path.join(base_dir, "_models")
data_dir = os.path.join(base_dir, "_data")
from utils.file import DATA_PATH, MODEL_PATH
model_output_name = "w2v_model"
model_path = os.path.join(model_dir, model_output_name)
model_path = os.path.join(MODEL_PATH, model_output_name)
try:
WORD2VEC_MODEL = word2vec.Word2Vec.load(model_path)
except Exception as e:
print(e)
tracate_click_ids_model_name = "tractate_click_ids_item2vec_model"
tractate_click_ids_model_path = os.path.join(model_dir, tracate_click_ids_model_name)
tractate_click_ids_model_path = os.path.join(MODEL_PATH, tracate_click_ids_model_name)
try:
TRACTATE_CLICK_IDS_MODEL = word2vec.Word2Vec.load(tractate_click_ids_model_path)
except Exception as e:
......@@ -39,11 +38,11 @@ class W2vSentences:
def w2v_train(f_name, model_output_name):
    """Train a Word2Vec model on a data file and save it to disk.

    Both paths are resolved against the shared DATA_PATH / MODEL_PATH
    locations only; the earlier assignments built from the legacy
    `data_dir` / `model_dir` globals were overwritten immediately and have
    been removed.

    Args:
        f_name: Name of the input file, joined onto DATA_PATH.
        model_output_name: Name of the saved model file, joined onto
            MODEL_PATH.
    """
    input_file = os.path.join(DATA_PATH, f_name)
    print("input: " + input_file)

    # W2vSentences wraps the input file and is handed to Word2Vec as the
    # training corpus.
    sentences = W2vSentences(input_file)
    w2v_model = word2vec.Word2Vec(sentences, min_count=2, workers=2, size=100, window=10)

    model_path = os.path.join(MODEL_PATH, model_output_name)
    print("output: " + model_path)
    w2v_model.save(model_path)
......@@ -92,7 +91,7 @@ def projects_item2vec(score_limit=5):
def save_clicked_tractate_ids_item2vec():
click_ids = []
with open(os.path.join(data_dir, "click_tractate_ids.csv"), "r") as f:
with open(os.path.join(DATA_PATH, "click_tractate_ids.csv"), "r") as f:
data = f.readlines()
for i in data:
tmp = i.split("|")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment