from pathlib import Path

import tensorflow as tf
from sklearn.model_selection import train_test_split

from models.esmm.fe import (click_feature_engineering, device_feature_engineering, diary_feature_engineering, join_features,
                            read_csv_data)
from models.esmm.input_fn import build_features, esmm_input_fn
from models.esmm.model import esmm_model_fn, model_export

# Import-time switch to eager mode for TF 1.x-style APIs.
# NOTE(review): under TF 2.x eager execution is already the default, and
# tf.estimator.Estimator builds its own graph for train/evaluate/predict, so
# this call may have no effect on main() — confirm the targeted TF version
# before removing it.
tf.compat.v1.enable_eager_execution()


def main(
    *,
    data_dir="~/Desktop/cvr_data/",
    model_dir="~/Desktop/models/",
    train_steps=5000,
    eval_steps=5000,
    random_state=None,
):
    """Train, evaluate, export and smoke-test the ESMM estimator end to end.

    Reads the raw CSVs via ``read_csv_data``, runs feature engineering, and
    splits the joined frame into train / validation / test sets (64% / 16% /
    20% of the rows). It then trains a ``tf.estimator.Estimator`` wrapping
    ``esmm_model_fn`` and evaluates it on the validation split. Finally it
    calls ``model_export`` and prints the first prediction on the test split.

    All arguments are keyword-only. Their defaults are the values this script
    previously hard-coded, so a bare ``main()`` runs the same pipeline.

    Args:
        data_dir: Directory handed to ``read_csv_data``; ``~`` is expanded.
        model_dir: Estimator ``model_dir``, also passed to ``model_export``;
            ``~`` is expanded.
        train_steps: ``steps`` passed to ``Estimator.train``.
        eval_steps: ``steps`` passed to ``Estimator.evaluate``.
        random_state: Seed forwarded to both ``train_test_split`` calls.
            ``None`` (the default) yields a different split on every run.

    Returns:
        The metrics dict returned by ``Estimator.evaluate`` on the validation
        split (also printed).
    """
    # Expand "~" for both directories. Path() does not do this on its own;
    # previously only the model dir was expanded, so the data dir was passed
    # through with a literal "~".
    data_path = Path(data_dir).expanduser()
    model_path = str(Path(model_dir).expanduser())

    device_df, diary_df, click_df, conversion_df = read_csv_data(data_path)
    device_df = device_feature_engineering(device_df)
    diary_df = diary_feature_engineering(diary_df)
    cc_df = click_feature_engineering(click_df, conversion_df)
    df = join_features(device_df, diary_df, cc_df)

    # 80/20 train/test split, then 80/20 of the remainder into train/val:
    # overall 64% train, 16% validation, 20% test.
    train_df, test_df = train_test_split(df, test_size=0.2, random_state=random_state)
    train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=random_state)

    # Feature columns are built from the full frame, not just train_df.
    # NOTE(review): presumably so categorical vocabularies also cover values
    # that only appear in val/test — confirm against build_features.
    all_features = build_features(df)

    params = {"feature_columns": all_features, "hidden_units": [32], "learning_rate": 0.1}
    model = tf.estimator.Estimator(model_fn=esmm_model_fn, params=params, model_dir=model_path)

    model.train(input_fn=lambda: esmm_input_fn(train_df, shuffle=True), steps=train_steps)
    # Keep the validation metrics instead of discarding them.
    metrics = model.evaluate(input_fn=lambda: esmm_input_fn(val_df, False), steps=eval_steps)
    print(metrics)
    model_export(model, all_features, model_path)

    # predict() yields results lazily; pull only the first one as a smoke test.
    predictions = model.predict(input_fn=lambda: esmm_input_fn(test_df, False))
    print(next(iter(predictions)))
    return metrics


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
