Commit 76fdbb84 authored by 张彦钊's avatar 张彦钊

delete auc function matplotlib

parent beb3b2e3
...@@ -7,4 +7,4 @@ if __name__ == "__main__": ...@@ -7,4 +7,4 @@ if __name__ == "__main__":
test = pd.read_csv(DIRECTORY_PATH + "test.csv",header = None) test = pd.read_csv(DIRECTORY_PATH + "test.csv",header = None)
test_label = test[0].apply(lambda x: x[0]).values test_label = test[0].apply(lambda x: x[0]).values
predict = pd.read_csv(DIRECTORY_PATH + "model.out",header = None)[0].values predict = pd.read_csv(DIRECTORY_PATH + "model.out",header = None)[0].values
get_roc_curve(test_label,predict,DIRECTORY_PATH+"auc_test.JPG") get_roc_curve(test_label,predict)
import pandas as pd import pandas as pd
from sklearn import metrics from sklearn import metrics
import matplotlib.pyplot as plt # import matplotlib.pyplot as plt
from sklearn.metrics import auc from sklearn.metrics import auc
import argparse import argparse
...@@ -12,7 +12,7 @@ parser.add_argument('output_photo',help='The filename of the output_photo') ...@@ -12,7 +12,7 @@ parser.add_argument('output_photo',help='The filename of the output_photo')
args = parser.parse_args() args = parser.parse_args()
def get_roc_curve(label,pred,output): def get_roc_curve(label,pred):
""" """
计算二分类问题的roc和auc 计算二分类问题的roc和auc
""" """
...@@ -24,15 +24,15 @@ def get_roc_curve(label,pred,output): ...@@ -24,15 +24,15 @@ def get_roc_curve(label,pred,output):
fpr, tpr, thresholds = metrics.roc_curve(y, p) fpr, tpr, thresholds = metrics.roc_curve(y, p)
plt.plot(fpr,tpr,marker = 'o') # plt.plot(fpr,tpr,marker = 'o')
plt.xlabel('False positive rate') # plt.xlabel('False positive rate')
plt.ylabel('True positive rate') # plt.ylabel('True positive rate')
plt.title('roc_cureve') # plt.title("roc_curev")
AUC = auc(fpr, tpr) AUC = auc(fpr, tpr)
AUC = "auc={}".format(AUC) AUC = "auc={}".format(AUC)
plt.text(0.5,0.8,AUC,color='blue',ha='center') # plt.text(0.5,0.8,AUC,color='blue',ha='center')
plt.savefig(output) # # plt.savefig(output)
print(AUC) print(AUC)
except: except:
print("the format of the file must be the n*1") print("the format of the file must be the n*1")
...@@ -41,4 +41,4 @@ def get_roc_curve(label,pred,output): ...@@ -41,4 +41,4 @@ def get_roc_curve(label,pred,output):
if __name__ == "__main__": if __name__ == "__main__":
get_roc_curve(args.test_label,args.test_pred,args.output_photo) get_roc_curve(args.test_label,args.test_pred)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment