Commit beb3b2e3 authored by 张彦钊's avatar 张彦钊

auc calculation

parent 72f26655
from eda.ml_tools.rocCurve import get_roc_curve
import pandas as pd
from config import *
if __name__ == "__main__":
    # Score the held-out test set: read the true labels and the model's
    # predicted probabilities, then hand both to get_roc_curve together with
    # the path of the image file to produce.
    test = pd.read_csv(DIRECTORY_PATH + "test.csv", header=None)

    # Each row is read as a single string whose first character is the label.
    # Cast it to int so the labels are numeric 0/1 rather than the strings
    # "0"/"1"; ROC computation on string labels requires an explicit
    # positive class and otherwise fails.
    test_label = test[0].apply(lambda x: int(x[0])).values

    # Predictions live in "output.txt". "model.out" is the trained model
    # file (it is the path given to fit() and passed to predict() as the
    # model), so reading it here would not yield scores.
    predict = pd.read_csv(DIRECTORY_PATH + "output.txt", header=None)[0].values

    # NOTE(review): get_roc_curve appears to call pd.read_table on its
    # arguments in at least one revision; confirm it accepts in-memory
    # arrays before relying on this call.
    get_roc_curve(test_label, predict, DIRECTORY_PATH + "auc_test.JPG")
...@@ -3,10 +3,10 @@ import xlearn as xl ...@@ -3,10 +3,10 @@ import xlearn as xl
from config import * from config import *
print("start training") print("Start training")
ffm_model = xl.create_ffm() ffm_model = xl.create_ffm()
ffm_model.setTrain(DIRECTORY_PATH + "data.csv") ffm_model.setTrain(DIRECTORY_PATH + "data.csv")
ffm_model.setValidate(DIRECTORY_PATH + "data.csv") ffm_model.setValidate(DIRECTORY_PATH + "validation.csv")
param = {'task':'binary', 'lr':0.03, param = {'task':'binary', 'lr':0.03,
...@@ -14,9 +14,13 @@ param = {'task':'binary', 'lr':0.03, ...@@ -14,9 +14,13 @@ param = {'task':'binary', 'lr':0.03,
ffm_model.fit(param, DIRECTORY_PATH + "model.out") ffm_model.fit(param, DIRECTORY_PATH + "model.out")
ffm_model.setTest(DIRECTORY_PATH + "data.csv") print("predicting")
ffm_model.setTest(DIRECTORY_PATH + "test.csv")
ffm_model.setSigmoid() ffm_model.setSigmoid()
ffm_model.predict(DIRECTORY_PATH + "model.out", ffm_model.predict(DIRECTORY_PATH + "model.out",
DIRECTORY_PATH + "output.txt") DIRECTORY_PATH + "output.txt")
print("end")
...@@ -21,11 +21,14 @@ def get_roc_curve(label,pred,output): ...@@ -21,11 +21,14 @@ def get_roc_curve(label,pred,output):
pred_label = pd.read_table(pred) pred_label = pd.read_table(pred)
y = test_label.values y = test_label.values
p = pred_label.values p = pred_label.values
fpr, tpr, thresholds = metrics.roc_curve(y, p) fpr, tpr, thresholds = metrics.roc_curve(y, p)
plt.plot(fpr,tpr,marker = 'o') plt.plot(fpr,tpr,marker = 'o')
plt.xlabel('False positive rate') plt.xlabel('False positive rate')
plt.ylabel('True positive rate') plt.ylabel('True positive rate')
plt.title('roc_cureve') plt.title('roc_cureve')
AUC = auc(fpr, tpr) AUC = auc(fpr, tpr)
AUC = "auc={}".format(AUC) AUC = "auc={}".format(AUC)
plt.text(0.5,0.8,AUC,color='blue',ha='center') plt.text(0.5,0.8,AUC,color='blue',ha='center')
...@@ -36,5 +39,6 @@ def get_roc_curve(label,pred,output): ...@@ -36,5 +39,6 @@ def get_roc_curve(label,pred,output):
print("the test_label must be 0 or 1") print("the test_label must be 0 or 1")
print("the test_pred must be at [0,1]") print("the test_pred must be at [0,1]")
if __name__ == "__main__": if __name__ == "__main__":
get_roc_curve(args.test_label,args.test_pred,args.output_photo) get_roc_curve(args.test_label,args.test_pred,args.output_photo)
...@@ -10,15 +10,18 @@ exposure, click, click_device_id = fetch_data( ...@@ -10,15 +10,18 @@ exposure, click, click_device_id = fetch_data(
# 求曝光表和点击表的差集合 # 求曝光表和点击表的差集合
print("曝光表处理前的样本个数") print("曝光表处理前的样本个数")
print(exposure.shape) print(exposure.shape)
exposure = exposure.append(click) exposure = exposure.append(click)
exposure = exposure.append(click) exposure = exposure.append(click)
subset = click.columns.tolist() subset = click.columns.tolist()
exposure = exposure.drop_duplicates(subset=subset,keep=False) exposure = exposure.drop_duplicates(subset=subset,keep=False)
print("差集后曝光表个数") print("差集后曝光表个数")
print(exposure.shape) print(exposure.shape)
exposure = exposure.loc[exposure["device_id"].isin(click_device_id)] exposure = exposure.loc[exposure["device_id"].isin(click_device_id)]
print("去除未点击用户后曝光表个数") print("去除未点击用户后曝光表个数")
print(exposure.shape) print(exposure.shape)
# 打标签 # 打标签
click["y"] = 1 click["y"] = 1
exposure["y"] = 0 exposure["y"] = 0
...@@ -65,6 +68,7 @@ test = data.loc[:test_number] ...@@ -65,6 +68,7 @@ test = data.loc[:test_number]
print("测试集大小") print("测试集大小")
print(test.shape[0]) print(test.shape[0])
test.to_csv(DIRECTORY_PATH + "test.csv",index = False,header = None) test.to_csv(DIRECTORY_PATH + "test.csv",index = False,header = None)
validation = data.loc[(test_number+1):(test_number+validation_number)] validation = data.loc[(test_number+1):(test_number+validation_number)]
print("验证集大小") print("验证集大小")
print(validation.shape[0]) print(validation.shape[0])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment