Commit ada16bb5 authored by 赵威's avatar 赵威

item2vec

parent 0254dd52
...@@ -4,6 +4,7 @@ import pandas as pd ...@@ -4,6 +4,7 @@ import pandas as pd
base_dir = os.getcwd() base_dir = os.getcwd()
DATA_PATH = os.path.join(base_dir, "_data") DATA_PATH = os.path.join(base_dir, "_data")
MODEL_PATH = os.path.join(base_dir, "_model")
def remove_file(path): def remove_file(path):
......
import os import os
import sys
from collections import defaultdict from collections import defaultdict
from datetime import date, timedelta
from pyspark import SparkConf sys.path.append(os.path.realpath("."))
from pyspark.sql import SparkSession
from pytispark import pytispark as pti
base_dir = os.getcwd() from utils.date import get_ndays_before_no_minus, get_ndays_before_with_format
print("base_dir: " + base_dir) from utils.files import DATA_PATH
data_dir = os.path.join(base_dir, "_data") from utils.spark import get_spark
def get_ndays_before_with_format(n, format):
    """Return the date ``n`` days before today, rendered with ``strftime``.

    Args:
        n: Number of days to go back from today (0 means today).
        format: A ``strftime`` format string, e.g. ``"%Y%m%d"``.
            NOTE: the name shadows the ``format`` builtin, but it is kept
            unchanged for backward compatibility with keyword callers.

    Returns:
        The formatted date string.
    """
    # Subtract a positive timedelta instead of adding a negative one;
    # same result, clearer intent. (The old local was misleadingly
    # named `yesterday` even though n is arbitrary.)
    return (date.today() - timedelta(days=n)).strftime(format)
def get_ndays_before_no_minus(n):
    """Date ``n`` days before today as a compact YYYYMMDD string (no separators)."""
    compact_fmt = "%Y%m%d"
    return get_ndays_before_with_format(n, compact_fmt)
def get_spark(app_name=""):
    """Create a SparkSession configured for Hive and TiSpark access.

    Args:
        app_name: Spark application name shown in the UI; defaults to "".

    Returns:
        A SparkSession with Hive support enabled, log level set to ERROR,
        and a TiContext that has mapped the TiDB database "jerry_test"
        (so its tables become queryable through this session).

    NOTE(review): this talks to an external cluster (TiDB placement driver
    at a hardcoded address) — it cannot run in isolation.
    """
    sparkConf = SparkConf()
    # Permit cartesian products in SQL joins (disabled by default in Spark).
    sparkConf.set("spark.sql.crossJoin.enabled", True)
    # Raise the field cap for debug string rendering of wide plans/rows.
    sparkConf.set("spark.debug.maxToStringFields", "100")
    # TiSpark index-read behavior: forbid double reads, allow index reads.
    sparkConf.set("spark.tispark.plan.allow_index_double_read", False)
    sparkConf.set("spark.tispark.plan.allow_index_read", True)
    # Read Hive tables whose data is nested in subdirectories.
    sparkConf.set("spark.hive.mapred.supports.subdirectories", True)
    sparkConf.set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", True)
    sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    # Disable MapReduce output/map-output compression for this job's writes.
    sparkConf.set("mapreduce.output.fileoutputformat.compress", False)
    sparkConf.set("mapreduce.map.output.compress", False)
    # TiExtensions + the PD address wire TiSpark into this session.
    # NOTE(review): PD address is hardcoded — presumably environment-specific;
    # consider externalizing to config.
    spark = SparkSession.builder.config(conf=sparkConf).config(
        "spark.sql.extensions",
        "org.apache.spark.sql.TiExtensions").config("spark.tispark.pd.addresses",
                                                    "172.16.40.170:2379").appName(app_name).enableHiveSupport().getOrCreate()
    sc = spark.sparkContext
    # Quiet the default INFO/WARN chatter from Spark.
    sc.setLogLevel("ERROR")
    # sc.addPyFile("/srv/apps/strategy_embedding/utils/date.py")
    # Side effect: map the TiDB database so its tables resolve in spark.sql().
    ti = pti.TiContext(spark)
    ti.tidbMapDatabase("jerry_test")
    return spark
def get_tracate_click_data(spark, start, end): def get_tracate_click_data(spark, start, end):
...@@ -164,7 +130,7 @@ if __name__ == "__main__": ...@@ -164,7 +130,7 @@ if __name__ == "__main__":
res_dict = get_device_click_tractate_ids_dict(click_df) res_dict = get_device_click_tractate_ids_dict(click_df)
with open(os.path.join(data_dir, "click_tractate_ids.csv"), "w") as f: with open(os.path.join(DATA_PATH, "click_tractate_ids.csv"), "w") as f:
for (k, v) in res_dict.items(): for (k, v) in res_dict.items():
if v: if v:
f.write("{}|{}\n".format(k, ",".join([str(x) for x in v]))) f.write("{}|{}\n".format(k, ",".join([str(x) for x in v])))
......
import multiprocessing import multiprocessing
import os import os
import sys
import time import time
import traceback import traceback
sys.path.append(os.path.realpath("."))
from gensim.models import Word2Vec, word2vec from gensim.models import Word2Vec, word2vec
from gm_rpcd.all import bind from gm_rpcd.all import bind
from utils.es import es_scan from utils.es import es_scan
from utils.message import send_msg_to_dingtalk from utils.message import send_msg_to_dingtalk
from utils.file import DATA_PATH, MODEL_PATH
base_dir = os.getcwd()
print("base_dir: " + base_dir)
model_dir = os.path.join(base_dir, "_models")
data_dir = os.path.join(base_dir, "_data")
model_output_name = "w2v_model" model_output_name = "w2v_model"
model_path = os.path.join(model_dir, model_output_name) model_path = os.path.join(MODEL_PATH, model_output_name)
try: try:
WORD2VEC_MODEL = word2vec.Word2Vec.load(model_path) WORD2VEC_MODEL = word2vec.Word2Vec.load(model_path)
except Exception as e: except Exception as e:
print(e) print(e)
tracate_click_ids_model_name = "tractate_click_ids_item2vec_model" tracate_click_ids_model_name = "tractate_click_ids_item2vec_model"
tractate_click_ids_model_path = os.path.join(model_dir, tracate_click_ids_model_name) tractate_click_ids_model_path = os.path.join(MODEL_PATH, tracate_click_ids_model_name)
try: try:
TRACTATE_CLICK_IDS_MODEL = word2vec.Word2Vec.load(tractate_click_ids_model_path) TRACTATE_CLICK_IDS_MODEL = word2vec.Word2Vec.load(tractate_click_ids_model_path)
except Exception as e: except Exception as e:
...@@ -39,11 +38,11 @@ class W2vSentences: ...@@ -39,11 +38,11 @@ class W2vSentences:
def w2v_train(f_name, model_output_name): def w2v_train(f_name, model_output_name):
input_file = os.path.join(data_dir, f_name) input_file = os.path.join(DATA_PATH, f_name)
print("input: " + input_file) print("input: " + input_file)
sentences = W2vSentences(input_file) sentences = W2vSentences(input_file)
w2v_model = word2vec.Word2Vec(sentences, min_count=2, workers=2, size=100, window=10) w2v_model = word2vec.Word2Vec(sentences, min_count=2, workers=2, size=100, window=10)
model_path = os.path.join(model_dir, model_output_name) model_path = os.path.join(MODEL_PATH, model_output_name)
print("output: " + model_path) print("output: " + model_path)
w2v_model.save(model_path) w2v_model.save(model_path)
...@@ -92,7 +91,7 @@ def projects_item2vec(score_limit=5): ...@@ -92,7 +91,7 @@ def projects_item2vec(score_limit=5):
def save_clicked_tractate_ids_item2vec(): def save_clicked_tractate_ids_item2vec():
click_ids = [] click_ids = []
with open(os.path.join(data_dir, "click_tractate_ids.csv"), "r") as f: with open(os.path.join(DATA_PATH, "click_tractate_ids.csv"), "r") as f:
data = f.readlines() data = f.readlines()
for i in data: for i in data:
tmp = i.split("|") tmp = i.split("|")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment