Commit f55521ba authored by 赵威's avatar 赵威

try predict

parent 2897f11b
......@@ -25,6 +25,7 @@ def main():
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# device_df, diary_df, click_df, conversion_df = read_csv_data(Path("~/data/cvr_data").expanduser())
device_df, diary_df, click_df, conversion_df = read_csv_data(Path("/srv/apps/node2vec_git/cvr_data/"))
# print(diary_df.sample(1))
device_df = device_feature_engineering(device_df)
......@@ -40,8 +41,8 @@ def main():
all_features = build_features(df)
params = {"feature_columns": all_features, "hidden_units": [64, 32], "learning_rate": 0.1}
model_path = str(Path("~/data/model_tmp/").expanduser())
# if os.path.exists(model_path):
# shutil.rmtree(model_path)
if os.path.exists(model_path):
shutil.rmtree(model_path)
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.allow_growth = True
......@@ -53,9 +54,9 @@ def main():
eval_spec = tf.estimator.EvalSpec(input_fn=lambda: esmm_input_fn(val_df, shuffle=False))
tf.estimator.train_and_evaluate(model, train_spec, eval_spec)
# model.train(input_fn=lambda: esmm_input_fn(train_df, shuffle=True))
# metrics = model.evaluate(input_fn=lambda: esmm_input_fn(val_df, False))
# print("metrics: " + str(metrics))
model.train(input_fn=lambda: esmm_input_fn(train_df, shuffle=True))
metrics = model.evaluate(input_fn=lambda: esmm_input_fn(val_df, False))
print("metrics: " + str(metrics))
model_export_path = str(Path("~/data/models/").expanduser())
save_path = model_export(model, all_features, model_export_path)
......@@ -69,10 +70,6 @@ def main():
predict_fn = tf.contrib.predictor.from_saved_model(save_path)
# for i in range(5):
# test_300 = test_df.sample(300)
# model_predict(test_300, predict_fn)
print("==============================")
# device_id = "861601036552944"
# diary_ids = [
......
......@@ -19,7 +19,8 @@ def read_csv_data(dataset_path):
def get_device_dict_from_redis():
db_key = "cvr:db:device"
# TODO
db_key = "cvr:db:device2"
column_key = db_key + ":column"
columns = str(redis_db_client.get(column_key), "utf-8").split("|")
d = redis_db_client.hgetall(db_key)
......@@ -86,15 +87,18 @@ def device_feature_engineering(df):
device_df["second_positions"] = device_df["second_positions"].apply(lambda d: d if isinstance(d, list) else [])
device_df["projects"] = device_df["projects"].apply(lambda d: d if isinstance(d, list) else [])
device_df["city_first"] = device_df["city_first"].fillna("")
device_df["model_first"] = device_df["model_first"].fillna("")
nullseries = device_df.isnull().sum()
print("device:")
print(nullseries[nullseries > 0])
print(device_df.shape)
device_columns = [
"device_id", "active_type", "active_days", "past_consume_ability_history", "potential_consume_ability_history",
"price_sensitive_history", "first_demands", "second_demands", "first_solutions", "second_solutions", "first_positions",
"second_positions", "projects"
"device_id", "active_type", "active_days", "channel_first", "city_first", "model_first", "past_consume_ability_history",
"potential_consume_ability_history", "price_sensitive_history", "first_demands", "second_demands", "first_solutions",
"second_solutions", "first_positions", "second_positions", "projects"
]
return device_df[device_columns]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment