Commit 43fd8160 authored by 赵威's avatar 赵威

add train script

parent 7522fe18
if ! ps aux | grep src/train_diary.py | grep -v grep
then
nohup /home/gmuser/.virtualenvs/tf1/bin/python3 /srv/apps/gm_strategy_cvr/src/train_diary.py > ~/ctcvr_diary.log &
fi
if ! ps aux | grep src/train_tractate.py | grep -v grep
then
nohup /home/gmuser/.virtualenvs/tf1/bin/python3 /srv/apps/gm_strategy_cvr/src/train_tractate.py > ~/ctcvr_tractate.log &
fi
......@@ -8,7 +8,7 @@ import tensorflow as tf
from models.esmm.diary_model import model_predict_diary
from models.esmm.fe import device_fe, diary_fe, tractate_fe
from models.esmm.tractate_model import model_predict_tractate
from utils.cache import redis_client2
from utils.cache import get_essm_model_save_path, redis_client2
from utils.grey import recommed_service_category_device_id_by_tail
from utils.portrait import (get_user_portrait_tag3_read_v2, user_portrait_tag3_get_candidate_dict,
user_portrait_tag3_get_candidate_unread_list, user_portrait_tag3_write_ctcvr_data)
......@@ -77,10 +77,16 @@ def main():
tractate_dict = tractate_fe.get_tractate_dict_from_redis()
print("redis data: " + str(len(device_dict)) + " " + str(len(diary_dict)) + " " + str(len(tractate_dict)))
diary_save_path = "/home/gmuser/data/models/diary/1596509008"
diary_save_path = get_essm_model_save_path("diary")
if not diary_save_path:
diary_save_path = "/home/gmuser/data/models/diary/1596509008"
print(diary_save_path + "!!!!!!!!!!!!!!!!!!!!!!!!!!!")
diary_predict_fn = tf.contrib.predictor.from_saved_model(diary_save_path)
tractate_save_path = "/home/gmuser/data/models/tractate/1596509299"
tractate_save_path = get_essm_model_save_path("tractate")
if not tractate_save_path:
tractate_save_path = "/home/gmuser/data/models/tractate/1596509299"
print(tractate_save_path + "!!!!!!!!!!!!!!!!!!!!!!!!!!!")
tractate_predict_fn = tf.contrib.predictor.from_saved_model(tractate_save_path)
device_id = "androidid_a25a1129c0b38f7b"
......
......@@ -13,6 +13,7 @@ from models.esmm.diary_model import PREDICTION_ALL_COLUMNS, model_predict_diary
from models.esmm.fe import click_fe, device_fe, diary_fe, fe
from models.esmm.input_fn import esmm_input_fn
from models.esmm.model import esmm_model_fn, model_export
from utils.cache import set_essm_model_save_path
def main():
......@@ -44,8 +45,8 @@ def main():
all_features = fe.build_features(df, diary_fe.INT_COLUMNS, diary_fe.FLOAT_COLUMNS, diary_fe.CATEGORICAL_COLUMNS)
params = {"feature_columns": all_features, "hidden_units": [64, 32], "learning_rate": 0.1}
model_path = str(Path("~/data/model_tmp/diary/").expanduser())
# if os.path.exists(model_path):
# shutil.rmtree(model_path)
if os.path.exists(model_path):
shutil.rmtree(model_path)
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.allow_growth = True
......@@ -60,6 +61,7 @@ def main():
model_export_path = str(Path("~/data/models/diary").expanduser())
save_path = model_export(model, all_features, model_export_path)
print("save to: " + save_path)
set_essm_model_save_path("diary", save_path)
diary_train_columns = set(diary_fe.INT_COLUMNS + diary_fe.FLOAT_COLUMNS + diary_fe.CATEGORICAL_COLUMNS)
diary_predict_columns = set(PREDICTION_ALL_COLUMNS)
......
......@@ -13,6 +13,7 @@ from models.esmm.fe import click_fe, device_fe, fe, tractate_fe
from models.esmm.input_fn import esmm_input_fn
from models.esmm.model import esmm_model_fn, model_export
from models.esmm.tractate_model import (PREDICTION_ALL_COLUMNS, model_predict_tractate)
from utils.cache import set_essm_model_save_path
def main():
......@@ -41,8 +42,8 @@ def main():
all_features = fe.build_features(df, tractate_fe.INT_COLUMNS, tractate_fe.FLOAT_COLUMNS, tractate_fe.CATEGORICAL_COLUMNS)
params = {"feature_columns": all_features, "hidden_units": [64, 32], "learning_rate": 0.1}
model_path = str(Path("~/data/model_tmp/tractate/").expanduser())
# if os.path.exists(model_path):
# shutil.rmtree(model_path)
if os.path.exists(model_path):
shutil.rmtree(model_path)
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.allow_growth = True
......@@ -57,6 +58,7 @@ def main():
model_export_path = str(Path("~/data/models/tractate/").expanduser())
save_path = model_export(model, all_features, model_export_path)
print("save to: " + save_path)
set_essm_model_save_path("tractate", save_path)
tractate_train_columns = set(tractate_fe.INT_COLUMNS + tractate_fe.FLOAT_COLUMNS + tractate_fe.CATEGORICAL_COLUMNS)
tractate_predict_columns = set(PREDICTION_ALL_COLUMNS)
......
......@@ -5,3 +5,21 @@ redis_client2 = redis.StrictRedis.from_url("redis://:ReDis!GmTx*0aN9@172.16.40.1
redis_client3 = redis.StrictRedis.from_url("redis://:ReDis!GmTx*0aN12@172.16.40.164:6379")
redis_client4 = redis.StrictRedis.from_url("redis://:XfkMCCdWDIU%ls$h@172.16.50.145:6379")
redis_db_client = redis.StrictRedis.from_url("redis://:ReDis!GmTx*0aN14@172.16.40.146:6379")
def _essm_model_save_path(content_type):
return "doris:essm:{}:model_path".format(content_type)
def get_essm_model_save_path(content_type):
key = _essm_model_save_path(content_type)
path = redis_client3.get(key)
if path:
path = str(path, "utf-8")
return path
return None
def set_essm_model_save_path(content_type, path):
key = _essm_model_save_path(content_type)
redis_client3.set(key, path)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment