Skip to content

Commit 1418417

Browse files
committed
fix automatic update of daily frequency data
1 parent bab50e8 commit 1418417

File tree

8 files changed

+77
-62
lines changed

8 files changed

+77
-62
lines changed

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,11 @@ We recommend users to prepare their own data if they have a high-quality dataset
171171
```
172172
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
173173
```
174-
* **script path**: *qlib/scripts/data_collector/yahoo/collector.py*
174+
* **script path**: *scripts/data_collector/yahoo/collector.py*
175175
176176
* Manual update of data
177177
```
178-
python qlib/scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
178+
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
179179
```
180180
* *trading_date*: start of trading day
181181
* *end_date*: end of trading day (not included)

docs/component/data.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,13 @@ Automatic update of daily frequency data
8282
8383
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
8484
85-
- **script path**: *qlib/scripts/data_collector/yahoo/collector.py*
85+
- **script path**: *scripts/data_collector/yahoo/collector.py*
8686

8787
- Manual update of data
8888

8989
.. code-block:: bash
9090
91-
python qlib/scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
91+
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
9292
9393
- *trading_date*: start of trading day
9494
- *end_date*: end of trading day (not included)

qlib/contrib/report/analysis_model/analysis_model_performance.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
7979
:param dist:
8080
:return:
8181
"""
82+
# NOTE: plotly.tools.mpl_to_plotly not actively maintained, resulting in errors in the new version of matplotlib,
83+
# ref: https://github.com/plotly/plotly.py/issues/2913#issuecomment-730071567
84+
# removing plotly.tools.mpl_to_plotly for greater compatibility with matplotlib versions
8285
_plt_fig = sm.qqplot(data.dropna(), dist=dist, fit=True, line="45")
8386
plt.close(_plt_fig)
8487
qqplot_data = _plt_fig.gca().lines
@@ -88,9 +91,7 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
8891
{
8992
"type": "scatter",
9093
"x": qqplot_data[0].get_xdata(),
91-
# 'x': [0, 1],
9294
"y": qqplot_data[0].get_ydata(),
93-
# 'y': [1, 2],
9495
"mode": "markers",
9596
"marker": {"color": "#19d3f3"},
9697
}
@@ -100,9 +101,7 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
100101
{
101102
"type": "scatter",
102103
"x": qqplot_data[1].get_xdata(),
103-
# 'x': [0, 1],
104104
"y": qqplot_data[1].get_ydata(),
105-
# 'y': [1, 2],
106105
"mode": "lines",
107106
"line": {"color": "#636efa"},
108107
}

scripts/data_collector/base.py

+10-17
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pandas as pd
1414
from tqdm import tqdm
1515
from loguru import logger
16+
from joblib import Parallel, delayed
1617
from qlib.utils import code_to_fname
1718

1819

@@ -186,20 +187,12 @@ def cache_small_data(self, symbol, df):
186187
def _collector(self, instrument_list):
187188

188189
error_symbol = []
189-
with tqdm(total=len(instrument_list)) as p_bar:
190-
if self.max_workers is not None and self.max_workers > 1:
191-
logger.info(f"concurrent collector, max_workers: {self.max_workers}")
192-
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
193-
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
194-
if _result != self.NORMAL_FLAG:
195-
error_symbol.append(_symbol)
196-
p_bar.update()
197-
else:
198-
for _symbol in instrument_list:
199-
_result = self._simple_collector(_symbol)
200-
if _result != self.NORMAL_FLAG:
201-
error_symbol.append(_symbol)
202-
p_bar.update()
190+
res = Parallel(n_jobs=self.max_workers)(
191+
delayed(self._simple_collector)(_inst) for _inst in tqdm(instrument_list)
192+
)
193+
for _symbol, _result in zip(instrument_list, res):
194+
if _result != self.NORMAL_FLAG:
195+
error_symbol.append(_symbol)
203196
print(error_symbol)
204197
logger.info(f"error symbol nums: {len(error_symbol)}")
205198
logger.info(f"current get symbol nums: {len(instrument_list)}")
@@ -365,7 +358,7 @@ def download_data(
365358
start=None,
366359
end=None,
367360
interval="1d",
368-
check_data_length=False,
361+
check_data_length: int = None,
369362
limit_nums=None,
370363
):
371364
"""download data from Internet
@@ -382,8 +375,8 @@ def download_data(
382375
start datetime, default "2000-01-01"
383376
end: str
384377
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
385-
check_data_length: bool
386-
check data length, by default False
378+
check_data_length: int
379+
check data length; if not None and greater than 0, each symbol is considered complete once its data length is greater than or equal to this value, otherwise it is fetched again (up to ``max_collector_count`` attempts). By default None.
387380
limit_nums: int
388381
using for debug, by default None
389382

scripts/data_collector/fund/collector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def download_data(
254254
start=None,
255255
end=None,
256256
interval="1d",
257-
check_data_length=None,
257+
check_data_length: int = None,
258258
limit_nums=None,
259259
):
260260
"""download data from Internet

scripts/data_collector/yahoo/README.md

+14-12
Original file line numberDiff line numberDiff line change
@@ -140,22 +140,24 @@ pip install -r requirements.txt
140140
```
141141
* * * * 1-5 python <script path> update_data_to_bin --qlib_data_1d_dir <user data dir>
142142
```
143-
* **script path**: *qlib/scripts/data_collector/yahoo/collector.py*
143+
* **script path**: *scripts/data_collector/yahoo/collector.py*
144144
145145
* Manual update of data
146146
```
147-
python qlib/scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
147+
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
148148
```
149-
* *trading_date*: start of trading day
150-
* *end_date*: end of trading day (not included)
151-
152-
* qlib/scripts/data_collector/yahoo/collector.py update_data_to_bin parameters:
153-
* *source_dir*: The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
154-
* *normalize_dir*: Directory for normalize data, default "Path(__file__).parent/normalize"
155-
* *qlib_data_1d_dir*: the qlib data to be updated for yahoo, usually from: [download qlib data](https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
156-
* *trading_date*: trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
157-
* *end_date*: end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval (excluding end)
158-
* *region*: region, value from ["CN", "US"], default "CN"
149+
* `trading_date`: start of trading day
150+
* `end_date`: end of trading day (not included)
151+
* `check_data_length`: check the number of rows per *symbol*, by default `None`
152+
> if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter
153+
154+
* `scripts/data_collector/yahoo/collector.py update_data_to_bin` parameters:
155+
* `source_dir`: The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
156+
* `normalize_dir`: Directory for normalize data, default "Path(__file__).parent/normalize"
157+
* `qlib_data_1d_dir`: the qlib data to be updated for yahoo, usually from: [download qlib data](https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
158+
* `trading_date`: trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
159+
* `end_date`: end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval (excluding end)
160+
* `region`: region, value from ["CN", "US"], default "CN"
159161
160162
161163
## Using qlib data

scripts/data_collector/yahoo/collector.py

+44-24
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import datetime
99
import importlib
1010
from abc import ABC
11+
import multiprocessing
1112
from pathlib import Path
1213
from typing import Iterable
1314

@@ -49,7 +50,7 @@ def __init__(
4950
max_workers=4,
5051
max_collector_count=2,
5152
delay=0,
52-
check_data_length: bool = False,
53+
check_data_length: int = None,
5354
limit_nums: int = None,
5455
):
5556
"""
@@ -70,8 +71,8 @@ def __init__(
7071
start datetime, default None
7172
end: str
7273
end datetime, default None
73-
check_data_length: bool
74-
check data length, by default False
74+
check_data_length: int
75+
check data length, by default None
7576
limit_nums: int
7677
using for debug, by default None
7778
"""
@@ -311,8 +312,8 @@ def normalize_yahoo(
311312
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan
312313
_tmp_series = df["close"].fillna(method="ffill")
313314
_tmp_shift_series = _tmp_series.shift(1)
314-
if last_close is not None and isinstance(last_close, (int, float)):
315-
_tmp_shift_series.iloc[0] = last_close
315+
if last_close is not None:
316+
_tmp_shift_series.iloc[0] = float(last_close)
316317
df["change"] = _tmp_series / _tmp_shift_series - 1
317318
columns += ["change"]
318319
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
@@ -408,41 +409,44 @@ def __init__(
408409
symbol field name, default is symbol
409410
"""
410411
super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name)
411-
self._end_date, self._old_close = self._get_old_data(old_qlib_data_dir)
412-
self._end_date = pd.Timestamp(self._end_date).strftime(self.DAILY_FORMAT)
412+
self._first_close_field = "first_close"
413+
self._ori_close_field = "ori_close"
414+
self.old_qlib_data = self._get_old_data(old_qlib_data_dir)
413415

414416
def _get_old_data(self, qlib_data_dir: [str, Path]):
415417
import qlib
416418
from qlib.data import D
417419

418420
qlib_data_dir = str(Path(qlib_data_dir).expanduser().resolve())
419421
qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None)
420-
df = D.features(D.instruments("all"), ["$close/$factor"])
421-
df.columns = ["close"]
422-
return D.calendar()[-1], df
422+
df = D.features(D.instruments("all"), ["$close/$factor", "$adjclose/$close"])
423+
df.columns = [self._ori_close_field, self._first_close_field]
424+
return df
425+
426+
def _get_close(self, df: pd.DataFrame, field_name: str):
427+
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
428+
_df = self.old_qlib_data.loc(axis=0)[_symbol]
429+
_close = _df.loc[_df.last_valid_index()][field_name]
430+
return _close
423431

424432
def _get_first_close(self, df: pd.DataFrame) -> float:
425-
_symbol = df.iloc[0][self._symbol_field_name]
426433
try:
427-
_df = self._old_close.loc(axis=0)[_symbol.upper()]
428-
_close = _df.loc[_df.first_valid_index()]["close"]
434+
_close = self._get_close(df, field_name=self._first_close_field)
429435
except KeyError:
430436
_close = super(YahooNormalize1dExtend, self)._get_first_close(df)
431437
return _close
432438

433439
def _get_last_close(self, df: pd.DataFrame) -> float:
434-
_symbol = df.iloc[0][self._symbol_field_name]
435440
try:
436-
_df = self._old_close.loc(axis=0)[_symbol.upper()]
437-
_close = _df.loc[_df.last_valid_index()]["close"]
441+
_close = self._get_close(df, field_name=self._ori_close_field)
438442
except KeyError:
439443
_close = None
440444
return _close
441445

442446
def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp:
443-
_symbol = df.iloc[0][self._symbol_field_name]
447+
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
444448
try:
445-
_df = self._old_close.loc(axis=0)[_symbol.upper()]
449+
_df = self.old_qlib_data.loc(axis=0)[_symbol]
446450
_date = _df.index.max()
447451
except KeyError:
448452
_date = None
@@ -901,7 +905,7 @@ def normalize_data_1d_extend(
901905
def download_today_data(
902906
self,
903907
max_collector_count=2,
904-
delay=0,
908+
delay=0.5,
905909
check_data_length=None,
906910
limit_nums=None,
907911
):
@@ -912,7 +916,7 @@ def download_today_data(
912916
max_collector_count: int
913917
default 2
914918
delay: float
915-
time.sleep(delay), default 0
919+
time.sleep(delay), default 0.5
916920
check_data_length: int
917921
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
918922
limit_nums: int
@@ -947,7 +951,14 @@ def download_today_data(
947951
limit_nums,
948952
)
949953

950-
def update_data_to_bin(self, qlib_data_1d_dir: str, trading_date: str = None, end_date: str = None):
954+
def update_data_to_bin(
955+
self,
956+
qlib_data_1d_dir: str,
957+
trading_date: str = None,
958+
end_date: str = None,
959+
check_data_length: int = None,
960+
delay: float = 1,
961+
):
951962
"""update yahoo data to bin
952963
953964
Parameters
@@ -959,7 +970,10 @@ def update_data_to_bin(self, qlib_data_1d_dir: str, trading_date: str = None, en
959970
trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
960971
end_date: str
961972
end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
962-
973+
check_data_length: int
974+
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
975+
delay: float
976+
time.sleep(delay), default 1
963977
Notes
964978
-----
965979
If the data in qlib_data_dir is incomplete, np.nan will be populated to trading_date for the previous trading day
@@ -987,8 +1001,14 @@ def update_data_to_bin(self, qlib_data_1d_dir: str, trading_date: str = None, en
9871001
GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region)
9881002

9891003
# download data from yahoo
990-
self.download_data(delay=1, start=trading_date, end=end_date, check_data_length=1)
991-
1004+
# NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1
1005+
self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length)
1006+
# NOTE: a larger max_workers setting here would be faster
1007+
self.max_workers = (
1008+
max(multiprocessing.cpu_count() - 2, 1)
1009+
if self.max_workers is None or self.max_workers <= 1
1010+
else self.max_workers
1011+
)
9921012
# normalize data
9931013
self.normalize_data_1d_extend(qlib_data_1d_dir)
9941014

scripts/data_collector/yahoo/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ tqdm
77
lxml
88
loguru
99
yahooquery
10+
joblib

0 commit comments

Comments
 (0)