Skip to content

Commit f6b019d

Browse files
authored
Merge pull request microsoft#328 from D-X-Y/fshare
Move get_path to get_or_create_path, use the best model of SFM / TabNet
2 parents e8beaa5 + e626264 commit f6b019d

12 files changed

+47
-41
lines changed

qlib/contrib/model/pytorch_alstm.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ...utils import (
1515
unpack_archive_with_buffer,
1616
save_multiple_parts_file,
17-
create_save_path,
17+
get_or_create_path,
1818
drop_nan_by_y_index,
1919
)
2020
from ...log import get_module_logger, TimeInspector
@@ -230,8 +230,7 @@ def fit(
230230
x_train, y_train = df_train["feature"], df_train["label"]
231231
x_valid, y_valid = df_valid["feature"], df_valid["label"]
232232

233-
if save_path == None:
234-
save_path = create_save_path(save_path)
233+
save_path = get_or_create_path(save_path)
235234
stop_steps = 0
236235
train_loss = 0
237236
best_score = -np.inf

qlib/contrib/model/pytorch_alstm_ts.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ...utils import (
1515
unpack_archive_with_buffer,
1616
save_multiple_parts_file,
17-
create_save_path,
17+
get_or_create_path,
1818
drop_nan_by_y_index,
1919
)
2020
from ...log import get_module_logger, TimeInspector
@@ -220,8 +220,7 @@ def fit(
220220
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
221221
)
222222

223-
if save_path == None:
224-
save_path = create_save_path(save_path)
223+
save_path = get_or_create_path(save_path)
225224

226225
stop_steps = 0
227226
train_loss = 0

qlib/contrib/model/pytorch_gats.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ...utils import (
1515
unpack_archive_with_buffer,
1616
save_multiple_parts_file,
17-
create_save_path,
17+
get_or_create_path,
1818
drop_nan_by_y_index,
1919
)
2020
from ...log import get_module_logger, TimeInspector
@@ -248,8 +248,7 @@ def fit(
248248
x_train, y_train = df_train["feature"], df_train["label"]
249249
x_valid, y_valid = df_valid["feature"], df_valid["label"]
250250

251-
if save_path == None:
252-
save_path = create_save_path(save_path)
251+
save_path = get_or_create_path(save_path)
253252
stop_steps = 0
254253
best_score = -np.inf
255254
best_epoch = 0

qlib/contrib/model/pytorch_gats_ts.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ...utils import (
1515
unpack_archive_with_buffer,
1616
save_multiple_parts_file,
17-
create_save_path,
17+
get_or_create_path,
1818
drop_nan_by_y_index,
1919
)
2020
from ...log import get_module_logger, TimeInspector
@@ -264,8 +264,7 @@ def fit(
264264
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs, drop_last=True)
265265
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs, drop_last=True)
266266

267-
if save_path == None:
268-
save_path = create_save_path(save_path)
267+
save_path = get_or_create_path(save_path)
269268

270269
stop_steps = 0
271270
train_loss = 0

qlib/contrib/model/pytorch_gru.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ...utils import (
1515
unpack_archive_with_buffer,
1616
save_multiple_parts_file,
17-
create_save_path,
17+
get_or_create_path,
1818
drop_nan_by_y_index,
1919
)
2020
from ...log import get_module_logger, TimeInspector
@@ -230,8 +230,7 @@ def fit(
230230
x_train, y_train = df_train["feature"], df_train["label"]
231231
x_valid, y_valid = df_valid["feature"], df_valid["label"]
232232

233-
if save_path == None:
234-
save_path = create_save_path(save_path)
233+
save_path = get_or_create_path(save_path)
235234
stop_steps = 0
236235
train_loss = 0
237236
best_score = -np.inf

qlib/contrib/model/pytorch_gru_ts.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ...utils import (
1515
unpack_archive_with_buffer,
1616
save_multiple_parts_file,
17-
create_save_path,
17+
get_or_create_path,
1818
drop_nan_by_y_index,
1919
)
2020
from ...log import get_module_logger, TimeInspector
@@ -220,8 +220,7 @@ def fit(
220220
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
221221
)
222222

223-
if save_path == None:
224-
save_path = create_save_path(save_path)
223+
save_path = get_or_create_path(save_path)
225224

226225
stop_steps = 0
227226
train_loss = 0

qlib/contrib/model/pytorch_lstm.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ...utils import (
1515
unpack_archive_with_buffer,
1616
save_multiple_parts_file,
17-
create_save_path,
17+
get_or_create_path,
1818
drop_nan_by_y_index,
1919
)
2020
from ...log import get_module_logger, TimeInspector
@@ -226,8 +226,7 @@ def fit(
226226
x_train, y_train = df_train["feature"], df_train["label"]
227227
x_valid, y_valid = df_valid["feature"], df_valid["label"]
228228

229-
if save_path == None:
230-
save_path = create_save_path(save_path)
229+
save_path = get_or_create_path(save_path)
231230
stop_steps = 0
232231
train_loss = 0
233232
best_score = -np.inf

qlib/contrib/model/pytorch_lstm_ts.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ...utils import (
1515
unpack_archive_with_buffer,
1616
save_multiple_parts_file,
17-
create_save_path,
17+
get_or_create_path,
1818
drop_nan_by_y_index,
1919
)
2020
from ...log import get_module_logger, TimeInspector
@@ -216,8 +216,7 @@ def fit(
216216
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
217217
)
218218

219-
if save_path == None:
220-
save_path = create_save_path(save_path)
219+
save_path = get_or_create_path(save_path)
221220

222221
stop_steps = 0
223222
train_loss = 0

qlib/contrib/model/pytorch_nn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ...model.base import Model
2020
from ...data.dataset import DatasetH
2121
from ...data.dataset.handler import DataHandlerLP
22-
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
22+
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path, drop_nan_by_y_index
2323
from ...log import get_module_logger, TimeInspector
2424
from ...workflow import R
2525

@@ -176,7 +176,7 @@ def fit(
176176
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
177177
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
178178

179-
save_path = create_save_path(save_path)
179+
save_path = get_or_create_path(save_path)
180180
stop_steps = 0
181181
train_loss = 0
182182
best_loss = np.inf

qlib/contrib/model/pytorch_sfm.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ...utils import (
1414
unpack_archive_with_buffer,
1515
save_multiple_parts_file,
16-
create_save_path,
16+
get_or_create_path,
1717
drop_nan_by_y_index,
1818
)
1919
from ...log import get_module_logger, TimeInspector
@@ -380,6 +380,7 @@ def fit(
380380
x_train, y_train = df_train["feature"], df_train["label"]
381381
x_valid, y_valid = df_valid["feature"], df_valid["label"]
382382

383+
save_path = get_or_create_path(save_path)
383384
stop_steps = 0
384385
train_loss = 0
385386
best_score = -np.inf
@@ -412,7 +413,10 @@ def fit(
412413
if stop_steps >= self.early_stop:
413414
self.logger.info("early stop")
414415
break
416+
415417
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
418+
self.sfm_model.load_state_dict(best_param)
419+
torch.save(best_param, save_path)
416420
if self.device != "cpu":
417421
torch.cuda.empty_cache()
418422

qlib/contrib/model/pytorch_tabnet.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ...utils import (
1313
unpack_archive_with_buffer,
1414
save_multiple_parts_file,
15-
create_save_path,
15+
get_or_create_path,
1616
drop_nan_by_y_index,
1717
)
1818
from ...log import get_module_logger, TimeInspector
@@ -117,10 +117,7 @@ def __init__(
117117
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
118118

119119
def pretrain_fn(self, dataset=DatasetH, pretrain_file="./pretrain/best.model"):
120-
# make a directory if pretrain directory does not exist
121-
if pretrain_file.startswith("./pretrain") and not os.path.exists("pretrain"):
122-
self.logger.info("make folder to store model...")
123-
os.makedirs("pretrain")
120+
get_or_create_path(pretrain_file)
124121

125122
[df_train, df_valid] = dataset.prepare(
126123
["pretrain", "pretrain_validation"],
@@ -181,6 +178,7 @@ def fit(
181178
df_train.fillna(df_train.mean(), inplace=True)
182179
x_train, y_train = df_train["feature"], df_train["label"]
183180
x_valid, y_valid = df_valid["feature"], df_valid["label"]
181+
save_path = get_or_create_path(save_path)
184182

185183
stop_steps = 0
186184
train_loss = 0
@@ -207,12 +205,16 @@ def fit(
207205
best_score = val_score
208206
stop_steps = 0
209207
best_epoch = epoch_idx
208+
best_param = copy.deepcopy(self.tabnet_model.state_dict())
210209
else:
211210
stop_steps += 1
212211
if stop_steps >= self.early_stop:
213212
self.logger.info("early stop")
214213
break
214+
215215
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
216+
self.tabnet_model.load_state_dict(best_param)
217+
torch.save(best_param, save_path)
216218

217219
def predict(self, dataset):
218220
if not self.fitted:

qlib/utils/__init__.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import numpy as np
2525
import pandas as pd
2626
from pathlib import Path
27-
from typing import Union, Tuple
27+
from typing import Union, Tuple, Text, Optional
2828

2929
from ..config import C
3030
from ..log import get_module_logger, set_log_with_config
@@ -276,23 +276,31 @@ def default(self, o):
276276
return changes
277277

278278

279-
def create_save_path(save_path=None):
280-
"""Create save path
279+
def get_or_create_path(path: Optional[Text] = None, return_dir: bool = False):
    """Create or get a file or directory given the path and return_dir.

    Parameters
    ----------
    path: a string indicates the path or None indicates creating a temporary path.
    return_dir: if True, create and return a directory; otherwise create and return a file.

    Returns
    -------
    The resulting path. When ``path`` is given it is returned unchanged: with
    ``return_dir=True`` the directory itself is created if missing, otherwise
    only the parent directory of the file is created (the file is not touched).
    When ``path`` is empty or None, a fresh temporary directory or an empty
    temporary file is created under ``~/tmp`` and its path is returned.
    """
    if path:
        if return_dir:
            if not os.path.exists(path):
                os.makedirs(path, exist_ok=True)
        else:
            # return a file, thus we need to create its parent directory
            parent_dir = os.path.abspath(os.path.join(path, ".."))
            if not os.path.exists(parent_dir):
                os.makedirs(parent_dir, exist_ok=True)
        return path

    temp_dir = os.path.expanduser("~/tmp")
    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir, exist_ok=True)
    if return_dir:
        # mkdtemp returns only the directory path (a plain string), not a tuple.
        return tempfile.mkdtemp(dir=temp_dir)
    # mkstemp returns (fd, path); close the descriptor so it is not leaked,
    # since callers reopen the file by path.
    fd, path = tempfile.mkstemp(dir=temp_dir)
    os.close(fd)
    return path
296304

297305

298306
@contextlib.contextmanager

0 commit comments

Comments
 (0)