-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest2.py
More file actions
35 lines (24 loc) · 826 Bytes
/
test2.py
File metadata and controls
35 lines (24 loc) · 826 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import sys
import numpy as np
from scipy.stats import describe
#from scipy.stats.mstats import zscore
from scipy.stats import norm as normal
from sklearn import datasets
from sklearn.metrics import roc_auc_score
from ComputeBart import ComputeBart
# Demo: separate handwritten 7s from 9s in the scikit-learn digits dataset
# with ComputeBart, convert its outputs to probabilities of "9", and report
# the ROC AUC.
#
# NOTE(review): x is passed both as training data and as prediction input to
# fit_and_predict, so the AUC reported below is an in-sample score and will
# overstate generalization performance.
digits = datasets.load_digits()

# Keep only the samples labelled 7 or 9.
sel_7_vs_9 = (digits.target == 7) | (digits.target == 9)
x = digits.data[sel_7_vs_9, :]
y = digits.target[sel_7_vs_9]
# Single-argument print(...) produces identical output on Python 2 and 3.
print('7 vs 9 dataset dims: %s' % (x.shape,))

# Encode the classes as +3.0 (digit 9) / -3.0 (digit 7) for the
# classification-mode (regression=False) fit.
# NOTE(review): presumably ComputeBart expects latent-scale targets of this
# form in classification mode (probit-style) — confirm against ComputeBart.
latent_target = np.where(y == 9, 3.0, -3.0)
bart = ComputeBart(regression=False)
result = bart.fit_and_predict(x, latent_target, x)

# Map the predictions through the standard normal CDF to obtain P(digit == 9).
# norm.cdf already operates element-wise on arrays, so np.vectorize (a
# per-element Python loop) is unnecessary.
standard_normal = normal()
probs = standard_normal.cdf(result)

# Boolean ground truth (True for digit 9) for the AUC computation.
# Plain `bool` replaces `np.bool`, which was removed in NumPy 1.24.
target = np.asarray(y == 9, dtype=bool)
print('Targets:')
print(target)
print('Probs:')
print(probs)
print('AUC = %s' % roc_auc_score(target, probs))