Commit beb3b2e3 authored by 张彦钊's avatar 张彦钊

AUC calculation

parent 72f26655
from eda.ml_tools.rocCurve import get_roc_curve
import pandas as pd
from config import *
if __name__ == "__main__":
    # Script entry point: load the test rows and the stored predictions from
    # DIRECTORY_PATH, then hand both arrays to get_roc_curve together with
    # the path of the image to produce.
    test_csv_path = DIRECTORY_PATH + "test.csv"
    prediction_path = DIRECTORY_PATH + "model.out"
    roc_image_path = DIRECTORY_PATH + "auc_test.JPG"

    # test.csv is read without a header row, so its columns are numbered.
    test = pd.read_csv(test_csv_path, header=None)

    # Keep only element 0 of every value in column 0.
    # presumably this is the leading label character of each row -- TODO confirm
    first_column = test[0]
    test_label = first_column.apply(lambda cell: cell[0]).values

    # NOTE(review): the training script passes "model.out" to fit()/predict()
    # as the model path and "output.txt" as the prediction output -- confirm
    # that "model.out" really is the file holding the predicted scores.
    prediction_frame = pd.read_csv(prediction_path, header=None)
    predict = prediction_frame[0].values

    get_roc_curve(test_label, predict, roc_image_path)
......@@ -3,10 +3,10 @@ import xlearn as xl
from config import *
print("start training")
print("Start training")
ffm_model = xl.create_ffm()
ffm_model.setTrain(DIRECTORY_PATH + "data.csv")
ffm_model.setValidate(DIRECTORY_PATH + "data.csv")
ffm_model.setValidate(DIRECTORY_PATH + "validation.csv")
param = {'task':'binary', 'lr':0.03,
......@@ -14,9 +14,13 @@ param = {'task':'binary', 'lr':0.03,
ffm_model.fit(param, DIRECTORY_PATH + "model.out")
ffm_model.setTest(DIRECTORY_PATH + "data.csv")
print("predicting")
ffm_model.setTest(DIRECTORY_PATH + "test.csv")
ffm_model.setSigmoid()
ffm_model.predict(DIRECTORY_PATH + "model.out",
DIRECTORY_PATH + "output.txt")
print("end")
......@@ -21,11 +21,14 @@ def get_roc_curve(label,pred,output):
pred_label = pd.read_table(pred)
y = test_label.values
p = pred_label.values
fpr, tpr, thresholds = metrics.roc_curve(y, p)
plt.plot(fpr,tpr,marker = 'o')
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.title('roc_cureve')
AUC = auc(fpr, tpr)
AUC = "auc={}".format(AUC)
plt.text(0.5,0.8,AUC,color='blue',ha='center')
......@@ -36,5 +39,6 @@ def get_roc_curve(label,pred,output):
print("the test_label must be 0 or 1")
print("the test_pred must be at [0,1]")
if __name__ == "__main__":
get_roc_curve(args.test_label,args.test_pred,args.output_photo)
......@@ -10,15 +10,18 @@ exposure, click, click_device_id = fetch_data(
# 求曝光表和点击表的差集合
print("曝光表处理前的样本个数")
print(exposure.shape)
exposure = exposure.append(click)
exposure = exposure.append(click)
subset = click.columns.tolist()
exposure = exposure.drop_duplicates(subset=subset,keep=False)
print("差集后曝光表个数")
print(exposure.shape)
exposure = exposure.loc[exposure["device_id"].isin(click_device_id)]
print("去除未点击用户后曝光表个数")
print(exposure.shape)
# 打标签
click["y"] = 1
exposure["y"] = 0
......@@ -65,6 +68,7 @@ test = data.loc[:test_number]
print("测试集大小")
print(test.shape[0])
test.to_csv(DIRECTORY_PATH + "test.csv",index = False,header = None)
validation = data.loc[(test_number+1):(test_number+validation_number)]
print("验证集大小")
print(validation.shape[0])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment