Commit ef5008d2 authored by 高雅喆's avatar 高雅喆

add eda/ml_tools/roc_curve.py

parent d61a97b5
import pandas as pd
from sklearn import metrics
import matplotlib.pyplot as plt
from sklearn.metrics import auc
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('test_label',help='The filename of the test_label')
parser.add_argument('test_pred',help='The filename of the test_pred')
parser.add_argument('output_photo',help='The filename of the output_photo')
args = parser.parse_args()
def get_roc_curve(label,pred,output):
"""
计算二分类问题的roc和auc
"""
try:
test_label = pd.read_table(label)
pred_label = pd.read_table(pred)
y = test_label.values
p = pred_label.values
fpr, tpr, thresholds = metrics.roc_curve(y, p)
plt.plot(fpr,tpr,marker = 'o')
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.title('roc_cureve')
AUC = auc(fpr, tpr)
AUC = "auc={}".format(AUC)
plt.text(0.5,0.8,AUC,color='blue',ha='center')
plt.savefig(output)
print(AUC)
except:
print("the format of the file must be the n*1")
print("the test_label must be 0 or 1")
print("the test_pred must be at [0,1]")
if __name__ == "__main__":
get_roc_curve(args.test_label,args.test_pred,args.output_photo)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment