Skip to content

Commit 3033fdf

Browse files
authored
Merge pull request #444 from zhupr/feature_importance
add get_feature_importance to model interpret
2 parents 43cad1e + ef11a9d commit 3033fdf

26 files changed

+592
-822
lines changed

Diff for: examples/highfreq/workflow.py

+3-16
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,13 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33

4-
import sys
54
import fire
6-
from pathlib import Path
75

86
import qlib
97
import pickle
10-
import numpy as np
11-
import pandas as pd
128
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
13-
from qlib.contrib.model.gbdt import LGBModel
14-
from qlib.contrib.data.handler import Alpha158
15-
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
16-
from qlib.contrib.evaluate import (
17-
backtest as normal_backtest,
18-
risk_analysis,
19-
)
20-
21-
from qlib.utils import init_instance_by_config, exists_qlib_data
9+
10+
from qlib.utils import init_instance_by_config
2211
from qlib.data.dataset.handler import DataHandlerLP
2312
from qlib.data.ops import Operators
2413
from qlib.data.data import Cal
@@ -96,9 +85,7 @@ def _init_qlib(self):
9685
# use yahoo_cn_1min data
9786
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
9887
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
99-
if not exists_qlib_data(provider_uri):
100-
print(f"Qlib data is not found in {provider_uri}")
101-
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
88+
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN, exists_skip=True)
10289
qlib.init(**QLIB_INIT_CONFIG)
10390

10491
def _prepare_calender_cache(self):
+14-44
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,9 @@
11
import qlib
2-
from qlib.config import REG_CN
3-
from qlib.utils import exists_qlib_data, init_instance_by_config
42
import optuna
5-
6-
provider_uri = "~/.qlib/qlib_data/cn_data"
7-
if not exists_qlib_data(provider_uri):
8-
print(f"Qlib data is not found in {provider_uri}")
9-
sys.path.append(str(scripts_dir))
10-
from get_data import GetData
11-
12-
GetData().qlib_data(target_dir=provider_uri, region="cn")
13-
qlib.init(provider_uri=provider_uri, region="cn")
14-
15-
market = "csi300"
16-
benchmark = "SH000300"
17-
18-
data_handler_config = {
19-
"start_time": "2008-01-01",
20-
"end_time": "2020-08-01",
21-
"fit_start_time": "2008-01-01",
22-
"fit_end_time": "2014-12-31",
23-
"instruments": market,
24-
}
25-
dataset_task = {
26-
"dataset": {
27-
"class": "DatasetH",
28-
"module_path": "qlib.data.dataset",
29-
"kwargs": {
30-
"handler": {
31-
"class": "Alpha158",
32-
"module_path": "qlib.contrib.data.handler",
33-
"kwargs": data_handler_config,
34-
},
35-
"segments": {
36-
"train": ("2008-01-01", "2014-12-31"),
37-
"valid": ("2015-01-01", "2016-12-31"),
38-
"test": ("2017-01-01", "2020-08-01"),
39-
},
40-
},
41-
},
42-
}
43-
dataset = init_instance_by_config(dataset_task["dataset"])
3+
from qlib.config import REG_CN
4+
from qlib.utils import init_instance_by_config
5+
from qlib.tests.config import CSI300_DATASET_CONFIG
6+
from qlib.tests.data import GetData
447

458

469
def objective(trial):
@@ -65,12 +28,19 @@ def objective(trial):
6528
},
6629
},
6730
}
68-
6931
evals_result = dict()
7032
model = init_instance_by_config(task["model"])
7133
model.fit(dataset, evals_result=evals_result)
7234
return min(evals_result["valid"])
7335

7436

75-
study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
76-
study.optimize(objective, n_jobs=6)
37+
if __name__ == "__main__":
38+
39+
provider_uri = "~/.qlib/qlib_data/cn_data"
40+
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
41+
qlib.init(provider_uri=provider_uri, region="cn")
42+
43+
dataset = init_instance_by_config(CSI300_DATASET_CONFIG)
44+
45+
study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
46+
study.optimize(objective, n_jobs=6)
+15-42
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,11 @@
11
import qlib
2-
from qlib.config import REG_CN
3-
from qlib.utils import exists_qlib_data, init_instance_by_config
42
import optuna
3+
from qlib.config import REG_CN
4+
from qlib.utils import init_instance_by_config
5+
from qlib.tests.data import GetData
6+
from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS
57

6-
provider_uri = "~/.qlib/qlib_data/cn_data"
7-
if not exists_qlib_data(provider_uri):
8-
print(f"Qlib data is not found in {provider_uri}")
9-
sys.path.append(str(scripts_dir))
10-
from get_data import GetData
11-
12-
GetData().qlib_data(target_dir=provider_uri, region="cn")
13-
qlib.init(provider_uri=provider_uri, region="cn")
14-
15-
market = "csi300"
16-
benchmark = "SH000300"
17-
18-
data_handler_config = {
19-
"start_time": "2008-01-01",
20-
"end_time": "2020-08-01",
21-
"fit_start_time": "2008-01-01",
22-
"fit_end_time": "2014-12-31",
23-
"instruments": market,
24-
}
25-
dataset_task = {
26-
"dataset": {
27-
"class": "DatasetH",
28-
"module_path": "qlib.data.dataset",
29-
"kwargs": {
30-
"handler": {
31-
"class": "Alpha360",
32-
"module_path": "qlib.contrib.data.handler",
33-
"kwargs": data_handler_config,
34-
},
35-
"segments": {
36-
"train": ("2008-01-01", "2014-12-31"),
37-
"valid": ("2015-01-01", "2016-12-31"),
38-
"test": ("2017-01-01", "2020-08-01"),
39-
},
40-
},
41-
},
42-
}
43-
dataset = init_instance_by_config(dataset_task["dataset"])
8+
DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS)
449

4510

4611
def objective(trial):
@@ -72,5 +37,13 @@ def objective(trial):
7237
return min(evals_result["valid"])
7338

7439

75-
study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
76-
study.optimize(objective, n_jobs=6)
40+
if __name__ == "__main__":
41+
42+
provider_uri = "~/.qlib/qlib_data/cn_data"
43+
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
44+
qlib.init(provider_uri=provider_uri, region=REG_CN)
45+
46+
dataset = init_instance_by_config(DATASET_CONFIG)
47+
48+
study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
49+
study.optimize(objective, n_jobs=6)

Diff for: examples/model_interpreter/feature.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
5+
import qlib
6+
from qlib.config import REG_CN
7+
8+
from qlib.utils import init_instance_by_config
9+
from qlib.tests.data import GetData
10+
from qlib.tests.config import CSI300_GBDT_TASK
11+
12+
13+
if __name__ == "__main__":
14+
15+
# use default data
16+
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
17+
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
18+
19+
qlib.init(provider_uri=provider_uri, region=REG_CN)
20+
21+
###################################
22+
# train model
23+
###################################
24+
# model initialization
25+
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
26+
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
27+
model.fit(dataset)
28+
29+
# get model feature importance
30+
feature_importance = model.get_feature_importance()
31+
print("feature importance:")
32+
print(feature_importance)

Diff for: examples/model_rolling/task_manager_rolling.py

+4-58
Original file line numberDiff line numberDiff line change
@@ -17,63 +17,7 @@
1717
from qlib.workflow.task.collect import RecorderCollector
1818
from qlib.model.ens.group import RollingGroup
1919
from qlib.model.trainer import TrainerRM
20-
21-
22-
data_handler_config = {
23-
"start_time": "2008-01-01",
24-
"end_time": "2020-08-01",
25-
"fit_start_time": "2008-01-01",
26-
"fit_end_time": "2014-12-31",
27-
"instruments": "csi100",
28-
}
29-
30-
dataset_config = {
31-
"class": "DatasetH",
32-
"module_path": "qlib.data.dataset",
33-
"kwargs": {
34-
"handler": {
35-
"class": "Alpha158",
36-
"module_path": "qlib.contrib.data.handler",
37-
"kwargs": data_handler_config,
38-
},
39-
"segments": {
40-
"train": ("2008-01-01", "2014-12-31"),
41-
"valid": ("2015-01-01", "2016-12-31"),
42-
"test": ("2017-01-01", "2020-08-01"),
43-
},
44-
},
45-
}
46-
47-
record_config = [
48-
{
49-
"class": "SignalRecord",
50-
"module_path": "qlib.workflow.record_temp",
51-
},
52-
{
53-
"class": "SigAnaRecord",
54-
"module_path": "qlib.workflow.record_temp",
55-
},
56-
]
57-
58-
# use lgb
59-
task_lgb_config = {
60-
"model": {
61-
"class": "LGBModel",
62-
"module_path": "qlib.contrib.model.gbdt",
63-
},
64-
"dataset": dataset_config,
65-
"record": record_config,
66-
}
67-
68-
# use xgboost
69-
task_xgboost_config = {
70-
"model": {
71-
"class": "XGBModel",
72-
"module_path": "qlib.contrib.model.xgboost",
73-
},
74-
"dataset": dataset_config,
75-
"record": record_config,
76-
}
20+
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
7721

7822

7923
class RollingTaskExample:
@@ -85,11 +29,13 @@ def __init__(
8529
task_db_name="rolling_db",
8630
experiment_name="rolling_exp",
8731
task_pool="rolling_task",
88-
task_config=[task_xgboost_config, task_lgb_config],
32+
task_config=None,
8933
rolling_step=550,
9034
rolling_type=RollingGen.ROLL_SD,
9135
):
9236
# TaskManager config
37+
if task_config is None:
38+
task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
9339
mongo_conf = {
9440
"task_url": task_url,
9541
"task_db_name": task_db_name,

Diff for: examples/online_srv/online_management_simulate.py

+4-58
Original file line numberDiff line numberDiff line change
@@ -13,63 +13,7 @@
1313
from qlib.workflow.online.strategy import RollingStrategy
1414
from qlib.workflow.task.gen import RollingGen
1515
from qlib.workflow.task.manage import TaskManager
16-
17-
18-
data_handler_config = {
19-
"start_time": "2018-01-01",
20-
"end_time": "2018-10-31",
21-
"fit_start_time": "2018-01-01",
22-
"fit_end_time": "2018-03-31",
23-
"instruments": "csi100",
24-
}
25-
26-
dataset_config = {
27-
"class": "DatasetH",
28-
"module_path": "qlib.data.dataset",
29-
"kwargs": {
30-
"handler": {
31-
"class": "Alpha158",
32-
"module_path": "qlib.contrib.data.handler",
33-
"kwargs": data_handler_config,
34-
},
35-
"segments": {
36-
"train": ("2018-01-01", "2018-03-31"),
37-
"valid": ("2018-04-01", "2018-05-31"),
38-
"test": ("2018-06-01", "2018-09-10"),
39-
},
40-
},
41-
}
42-
43-
record_config = [
44-
{
45-
"class": "SignalRecord",
46-
"module_path": "qlib.workflow.record_temp",
47-
},
48-
{
49-
"class": "SigAnaRecord",
50-
"module_path": "qlib.workflow.record_temp",
51-
},
52-
]
53-
54-
# use lgb model
55-
task_lgb_config = {
56-
"model": {
57-
"class": "LGBModel",
58-
"module_path": "qlib.contrib.model.gbdt",
59-
},
60-
"dataset": dataset_config,
61-
"record": record_config,
62-
}
63-
64-
# use xgboost model
65-
task_xgboost_config = {
66-
"model": {
67-
"class": "XGBModel",
68-
"module_path": "qlib.contrib.model.xgboost",
69-
},
70-
"dataset": dataset_config,
71-
"record": record_config,
72-
}
16+
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
7317

7418

7519
class OnlineSimulationExample:
@@ -84,7 +28,7 @@ def __init__(
8428
rolling_step=80,
8529
start_time="2018-09-10",
8630
end_time="2018-10-31",
87-
tasks=[task_xgboost_config, task_lgb_config],
31+
tasks=None,
8832
):
8933
"""
9034
Init OnlineManagerExample.
@@ -101,6 +45,8 @@ def __init__(
10145
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
10246
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
10347
"""
48+
if tasks is None:
49+
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
10450
self.exp_name = exp_name
10551
self.task_pool = task_pool
10652
self.start_time = start_time

0 commit comments

Comments
 (0)