forked from jjfeng/spinn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfit_trees.py
100 lines (91 loc) · 3.02 KB
/
fit_trees.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import sys
import time
import argparse
import pickle
import numpy as np
import pandas as pd
import logging
from scipy.stats import pearsonr
from common import *
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
def parse_args(args=None):
    """Parse the command line arguments for the random forest fitting script.

    Args:
        args: optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing zero-argument calls behave exactly as before.

    Returns:
        argparse.Namespace with the attributes ``seed``, ``max_depth``,
        ``num_trees``, ``num_jobs``, ``data_index_file``, ``data_X_file``,
        ``data_y_file``, ``data_file``, ``log_file``, ``out_file`` and
        ``data_classes``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--seed',
        type=int,
        help='seed',
        default=1)
    parser.add_argument(
        '--max-depth',
        type=int,
        help='maximum depth of each tree in the forest',
        default=10)
    parser.add_argument(
        '--num-trees',
        type=int,
        help='number of trees in the forest',
        default=1000)
    parser.add_argument(
        '--num-jobs',
        type=int,
        help='number of parallel jobs used by the forest',
        default=10)
    # The data-* file options are not read in this script itself; they are
    # presumably consumed by read_input_data -- TODO confirm.
    parser.add_argument(
        '--data-index-file',
        type=str,
        default=None)
    parser.add_argument(
        '--data-X-file',
        type=str)
    parser.add_argument(
        '--data-y-file',
        type=str)
    parser.add_argument(
        '--data-file',
        type=str,
        default="_output/data.pkl")
    parser.add_argument(
        '--log-file',
        type=str,
        help='file that log messages are written to',
        default="_output/log_trees.txt")
    parser.add_argument(
        '--out-file',
        type=str,
        help='csv file that the test loss is written to',
        default="_output/fitted_trees.csv")
    parser.add_argument(
        '--data-classes',
        type=int,
        help='0 fits a regression forest; any other value fits a classifier',
        default=0)
    return parser.parse_args(args)
def main(args=sys.argv[1:]):
    """Fit a random forest on the training data and save its test loss.

    Steps, in order: parse the command line, send logging to the
    ``--log-file``, seed numpy, load the train/test data through
    ``read_input_data``, fit a ``RandomForestRegressor`` when
    ``--data-classes`` is 0 (a ``RandomForestClassifier`` otherwise), log
    the largest feature importances, and write the test loss to
    ``--out-file`` as a csv.

    Args:
        args: kept for interface compatibility. NOTE: the value passed in
            is not used -- ``parse_args()`` is called without arguments
            and its result replaces this parameter.
    """
    args = parse_args()
    logging.basicConfig(
        format="%(message)s",
        filename=args.log_file,
        level=logging.DEBUG)
    np.random.seed(args.seed)
    logging.info(args)

    train_data, test_data = read_input_data(args)

    is_regression = args.data_classes == 0
    # Both forest flavours are configured identically.
    forest_kwargs = dict(
        max_depth=args.max_depth,
        random_state=args.seed,
        n_estimators=args.num_trees,
        max_features=0.3,
        n_jobs=args.num_jobs)
    if is_regression:
        regr = RandomForestRegressor(**forest_kwargs)
    else:
        # The classifier is trained and scored on integer labels.
        for split in (train_data, test_data):
            split.y = split.y.astype(int)
            split.y_true = split.y_true.astype(int)
        regr = RandomForestClassifier(**forest_kwargs)
    regr.fit(train_data.x, train_data.y.ravel())

    logging.info("FEATURE IMPORT")
    # feature_importances_ is a computed property that aggregates over all
    # trees on every access, so read it once instead of once per iteration.
    importances = regr.feature_importances_
    # argsort is ascending: the last 50 entries are the (up to) 50 most
    # important features, logged from least to most important.
    for idx in np.argsort(importances)[-50:]:
        logging.info("%d: %f", idx, importances[idx])

    y_pred = regr.predict(test_data.x)
    if is_regression:
        test_loss = get_regress_loss(y_pred, test_data.y_true)
        logging.info(
            "pearsonr %s",
            pearsonr(y_pred.ravel(), test_data.y_true.ravel()))
        print("PRED", y_pred)
        print(test_data.y_true)
    else:
        # Classification loss is reported as the misclassification rate.
        test_loss = 1 - get_classification_accuracy(y_pred, test_data.y_true)

    # Transposed so the csv has a single row labelled "test_loss".
    result = pd.DataFrame({"test_loss": [test_loss]})
    result.T.to_csv(args.out_file)
if __name__ == "__main__":
    # Only run the fitting pipeline when executed as a script, not on import.
    main(sys.argv[1:])