使用Python与Scikit-learn实现逻辑回归分析

https://towardsdatascience.com/logistic-regression-using-python-sklearn-numpy-mnist-handwriting-recognition-matplotlib-a6b31e2b166a

import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt

# Load the Ceftazidim data set from the CSV file in the working directory.
input_Cef = pd.read_csv("input_Cef.csv")
# Preview of the first rows; only rendered in an interactive session.
input_Cef.head()

# Feature matrix: the columns at positions 1 through 6026 (column 0 is skipped).
feature_columns = slice(1, 6027)
X = input_Cef.iloc[:, feature_columns]

# Target vector: the "Ceftazidim_S.vs.R" column.
target_column = "Ceftazidim_S.vs.R"
y = input_Cef[target_column]

# Step 1. Import the model and split the data into training and test sets
from sklearn.linear_model import LogisticRegression
# train_test_split was used below without ever being imported (NameError).
from sklearn.model_selection import train_test_split

# Hold out 20% of the rows for testing; the fixed random_state makes the
# shuffle, and therefore the split, reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=0
)

# Step 2. Make an instance of the model
# All parameters not specified are set to their defaults.
# Changing the solver had a minor effect on accuracy, but at least it was a lot faster
logreg = LogisticRegression(solver='lbfgs')

# Step 3. Train the model on the training split
logreg.fit(X_train, y_train)

# Step 4. Predict labels for the held-out test data
y_pred = logreg.predict(X_test)

# Step 5. Measure model performance on the held-out test set:
# accuracy, precision, recall, F1 score, ROC curve
from sklearn.metrics import classification_report

## Accuracy: fraction of test samples whose label is predicted correctly
score = logreg.score(X_test, y_test)
print(score)


## Precision, recall and F1 score, reported per class
report = classification_report(y_test, y_pred)
print(report)

## ROC curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve

# Predicted probability of the second class in logreg.classes_ for each test
# sample; these continuous scores drive both the curve and its area.
probas = logreg.predict_proba(X_test)[:,1]

# The AUC must be computed from the same scores as the plotted curve.
# Feeding the hard 0/1 predictions (y_pred) instead collapses the curve to a
# single operating point and reports an area that does not match the plot.
logit_roc_auc2 = roc_auc_score(y_test, probas)
fpr2, tpr2, thresholds2 = roc_curve(y_test, probas)

plt.figure()
plt.plot(fpr2, tpr2, label='Logistic Regression for Ceftazidim (area = %0.2f)' % logit_roc_auc2)
# Optional extras, left disabled: chance diagonal and fixed axis limits.
#plt.plot([0, 1], [0, 1],'r--')
#plt.xlim([0.0, 1.0])
#plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")
# Save before show(): showing first can leave an empty figure to be written.
plt.savefig('./Fig2_Log_ROC_Cef.pdf')
plt.show()