-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_config.py
122 lines (114 loc) · 4.38 KB
/
train_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# -*- coding: utf-8 -*-
# @File : train_config.py
# @Author : Hua Guo
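# @Disc : Training run settings (debug flag, data/model directories) and the train_config_detail experiment registry.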
from src.Pipeline.XGBRegressionPipeline import XGBRegressionPipeline
from src.FeatureCreator.FeatureCreator import FeatureCreator
from src.Pipeline.DeepFMPipeline import DeepFMPipeline
from src.Pipeline.XGBClassifierPipeline import XGBClassifierPipeline
from src.Pipeline.XGBRegClaPipeline import XGBRegClaPipeline
from src.Pipeline.XGBoostLR import XGBoostLR
from src.Pipeline.RidgeReg import RidgeReg
from src.config import breast_cancel_traget
from src.config import california_target
# Flip to True to use the '.../debug' data and model directories below.
debug = False
# Name of the experiment to run; both the active and the commented value are
# keys of train_config_detail, so presumably this selects an entry there
# (confirm against the caller).
dir_mark = "iris_cla"
# dir_mark = 'california_deepfm_reg'

# Raw-data and model directories, switched by the debug flag.
# NOTE(review): only the non-debug model_dir carries a trailing slash; check
# how callers join paths before relying on either form.
raw_data_path = 'data/debug' if debug else 'data/raw_data'
model_dir = 'model_training/debug' if debug else 'model_training/'
# Feature-name lists shared by several experiments below.  They are kept as
# tuples so the shared definition cannot be mutated by accident; every config
# entry builds its own fresh list from them (list(...) / [*...]), so entries
# never share a mutable list object.
_BREAST_CANCER_DENSE_FEATURES = (
    'mean_radius', 'mean_texture', 'mean_perimeter', 'mean_area',
    'mean_smoothness', 'mean_compactness', 'mean_concavity',
    'mean_concave_points', 'mean_symmetry', 'mean_fractal_dimension',
    'radius_error', 'texture_error', 'perimeter_error', 'area_error',
    'smoothness_error', 'compactness_error', 'concavity_error',
    'concave_points_error', 'symmetry_error', 'fractal_dimension_error',
    'worst_radius', 'worst_texture', 'worst_perimeter', 'worst_area',
    'worst_smoothness', 'worst_compactness', 'worst_concavity',
    'worst_concave_points', 'worst_symmetry', 'worst_fractal_dimension',
)
_CALIFORNIA_DENSE_FEATURES = (
    'MedInc', 'HouseAge', 'AveRooms', 'AveBedrms', 'Population', 'AveOccup',
    'Latitude', 'Longitude',
)
# Categorical columns used both as sparse features and for one-hot encoding
# in the "iris_cla" experiment.
_IRIS_CATEGORICAL_FEATURES = ('binary_feature', 'cate_feature')

# Experiment registry: experiment name -> settings dict.
# Keys seen in the entries below (their exact meaning is defined by the
# consuming pipeline code, not visible here):
#   pipeline_class   - pipeline class for the experiment
#   feature_creator  - feature-creator class (FeatureCreator everywhere here)
#   train_valid      - boolean flag; presumably enables a train/validation
#                      split - confirm in the pipeline
#   sparse_features  - list of sparse/categorical column names
#   dense_features   - list of dense/numeric column names
#   target_col       - name of the target column (imported from src.config)
#   onehot           - (optional) columns to one-hot encode
#   task, epochs, batch_size, dense_to_sparse - (DeepFM entries only)
#   cla_dir, reg_dir - (XGBRegClaPipeline entry only) presumably names of the
#                      classifier/regressor sub-experiments - confirm
train_config_detail = {
    "v1_0501_xgb_clareg": {
        'cla_dir': 'v1_0501_xgb_cla',
        'reg_dir': 'v1_0501_xgb_reg',
        'pipeline_class': XGBRegClaPipeline,
        'feature_creator': FeatureCreator,
        'train_valid': True,
        'sparse_features': [],
        'dense_features': [],
        # 'feature_clean_func': clean_map_feature,
        'target_col': breast_cancel_traget,
        # 'data_dir_mark': 'v1_0501_clareg',
    },
    "california_housing_reg": {
        "pipeline_class": XGBRegressionPipeline,
        'feature_creator': FeatureCreator,
        'train_valid': True,
        'sparse_features': [],
        'dense_features': list(_CALIFORNIA_DENSE_FEATURES),
        # 'feature_clean_func': clean_feature,
        'target_col': california_target,
    },
    # NOTE(review): despite the "iris" name, this entry uses the breast-cancer
    # feature columns and target; the key is left unchanged for callers.
    "iris_cla": {
        "pipeline_class": XGBClassifierPipeline,
        'feature_creator': FeatureCreator,
        'train_valid': True,
        'sparse_features': list(_IRIS_CATEGORICAL_FEATURES),
        'dense_features': ['zero_0', 'zero_1', *_BREAST_CANCER_DENSE_FEATURES],
        'onehot': list(_IRIS_CATEGORICAL_FEATURES),
        # 'feature_clean_func': clean_feature,
        'target_col': breast_cancel_traget,
    },
    "iris_deepfm_cla": {
        "pipeline_class": DeepFMPipeline,
        'task': 'binary',
        'feature_creator': FeatureCreator,
        'epochs': 2,
        'batch_size': 20,
        'dense_to_sparse': True,
        'train_valid': True,
        'sparse_features': [],
        'dense_features': list(_BREAST_CANCER_DENSE_FEATURES),
        # 'feature_clean_func': clean_feature,
        'target_col': breast_cancel_traget,
    },
    "california_deepfm_reg": {
        "pipeline_class": DeepFMPipeline,
        'task': 'regression',
        'feature_creator': FeatureCreator,
        'epochs': 2,
        'batch_size': 20,
        'dense_to_sparse': True,
        'train_valid': True,
        'sparse_features': [],
        'dense_features': list(_CALIFORNIA_DENSE_FEATURES),
        # 'feature_clean_func': clean_feature,
        'target_col': california_target,
    },
}