diaryTraining.py 743 Bytes
Newer Older
张彦钊's avatar
张彦钊 committed
1 2
import xlearn as xl
from config import *
3

4

张彦钊's avatar
张彦钊 committed
5 6 7
def train():
    print("Start training")
    ffm_model = xl.create_ffm()
张彦钊's avatar
张彦钊 committed
8 9
    ffm_model.setTrain(DIRECTORY_PATH + "train_ffm_data.csv")
    ffm_model.setValidate(DIRECTORY_PATH + "validation_ffm_data.csv")
10
    # log保存路径,如果不加这个参数,日志默认保存在/temp路径下,不符合规范
11
    param = {'task': 'binary', 'lr': lr, 'lambda': l2_lambda, 'metric': 'auc',"log":DIRECTORY_PATH+"result"}
张彦钊's avatar
张彦钊 committed
12

13
    ffm_model.fit(param, DIRECTORY_PATH + "train/model.out")
张彦钊's avatar
张彦钊 committed
14 15

    print("predicting")
张彦钊's avatar
张彦钊 committed
16
    ffm_model.setTest(DIRECTORY_PATH + "test_ffm_data.csv")
张彦钊's avatar
张彦钊 committed
17
    ffm_model.setSigmoid()
18
    ffm_model.predict(DIRECTORY_PATH + "train/model.out",DIRECTORY_PATH + "test_set_predict_output.txt")
19

20