@@ -12,7 +12,7 @@
 from ...utils import (
     unpack_archive_with_buffer,
     save_multiple_parts_file,
-    create_save_path,
+    get_or_create_path,
     drop_nan_by_y_index,
 )
 from ...log import get_module_logger, TimeInspector
@@ -117,10 +117,7 @@ def __init__(
             raise NotImplementedError("optimizer {} is not supported!".format(optimizer))

     def pretrain_fn(self, dataset=DatasetH, pretrain_file="./pretrain/best.model"):
-        # make a directory if pretrian director does not exist
-        if pretrain_file.startswith("./pretrain") and not os.path.exists("pretrain"):
-            self.logger.info("make folder to store model...")
-            os.makedirs("pretrain")
+        get_or_create_path(pretrain_file)

         [df_train, df_valid] = dataset.prepare(
             ["pretrain", "pretrain_validation"],
@@ -181,6 +178,7 @@ def fit(
         df_train.fillna(df_train.mean(), inplace=True)
         x_train, y_train = df_train["feature"], df_train["label"]
         x_valid, y_valid = df_valid["feature"], df_valid["label"]
+        save_path = get_or_create_path(save_path)

         stop_steps = 0
         train_loss = 0
@@ -207,12 +205,16 @@ def fit(
                 best_score = val_score
                 stop_steps = 0
                 best_epoch = epoch_idx
+                best_param = copy.deepcopy(self.tabnet_model.state_dict())
             else:
                 stop_steps += 1
                 if stop_steps >= self.early_stop:
                     self.logger.info("early stop")
                     break
+
         self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
+        self.tabnet_model.load_state_dict(best_param)
+        torch.save(best_param, save_path)

     def predict(self, dataset):
         if not self.fitted:
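The second half of the change follows the common "keep the best weights in memory" pattern: snapshot `state_dict()` on every validation improvement, then restore and persist that snapshot once training stops. A minimal self-contained sketch of the pattern with a toy model and dummy metric (not the actual TabNet training loop):

```python
import copy
import torch
import torch.nn as nn

model = nn.Linear(4, 1)  # toy stand-in for the real network
best_score = float("-inf")
best_param = copy.deepcopy(model.state_dict())

for epoch in range(10):
    # ... training step omitted; pretend val_score comes from an eval pass ...
    val_score = -float(epoch % 3)  # dummy validation metric
    if val_score > best_score:
        best_score = val_score
        best_param = copy.deepcopy(model.state_dict())  # snapshot the best weights

# Restore the best snapshot and persist it, mirroring the diff above.
model.load_state_dict(best_param)
torch.save(best_param, "best.model")
```

Deep-copying the state dict matters because `state_dict()` returns references to the live parameter tensors; without the copy, the "best" snapshot would silently track later updates.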