|
9 | 9 | import importlib
|
10 | 10 | from abc import ABC
|
11 | 11 | from pathlib import Path
|
12 |
| -from typing import Iterable, Type |
| 12 | +from typing import Iterable |
13 | 13 |
|
14 | 14 | import fire
|
15 | 15 | import requests
|
|
18 | 18 | from loguru import logger
|
19 | 19 | from yahooquery import Ticker
|
20 | 20 | from dateutil.tz import tzlocal
|
21 |
| -from qlib.utils import code_to_fname, fname_to_code |
| 21 | + |
| 22 | +from qlib.tests.data import GetData |
| 23 | +from qlib.utils import code_to_fname, fname_to_code, exists_qlib_data |
22 | 24 | from qlib.config import REG_CN as REGION_CN
|
23 | 25 |
|
24 | 26 | CUR_DIR = Path(__file__).resolve().parent
|
25 | 27 | sys.path.append(str(CUR_DIR.parent.parent))
|
| 28 | + |
| 29 | +from dump_bin import DumpDataUpdate |
26 | 30 | from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize
|
27 | 31 | from data_collector.utils import (
|
28 | 32 | deco_retry,
|
@@ -153,7 +157,10 @@ def _get_simple(start_, end_):
|
153 | 157 |
|
154 | 158 | _result = None
|
155 | 159 | if interval == self.INTERVAL_1d:
|
156 |
| - _result = _get_simple(start_datetime, end_datetime) |
| 160 | + try: |
| 161 | + _result = _get_simple(start_datetime, end_datetime) |
| 162 | + except ValueError as e: |
| 163 | + pass |
157 | 164 | elif interval == self.INTERVAL_1min:
|
158 | 165 | _res = []
|
159 | 166 | _start = self.start_datetime
|
@@ -184,7 +191,7 @@ def download_index_data(self):
|
184 | 191 |
|
185 | 192 | class YahooCollectorCN(YahooCollector, ABC):
|
186 | 193 | def get_instrument_list(self):
|
187 |
| - logger.info("get HS stock symbos......") |
| 194 | + logger.info("get HS stock symbols......") |
188 | 195 | symbols = get_hs_stock_symbols()
|
189 | 196 | logger.info(f"get {len(symbols)} symbols.")
|
190 | 197 | return symbols
|
@@ -233,9 +240,9 @@ def download_index_data(self):
|
233 | 240 |
|
234 | 241 |
|
235 | 242 | class YahooCollectorCN1min(YahooCollectorCN):
|
236 |
| - def download_index_data(self): |
237 |
| - # TODO: 1m |
238 |
| - logger.warning(f"{self.__class__.__name__} {self.interval} does not support: download_index_data") |
| 243 | + def get_instrument_list(self): |
| 244 | + symbols = super(YahooCollectorCN1min, self).get_instrument_list() |
| 245 | + return symbols + ["000300.ss", "000905.ss", "00903.ss"] |
239 | 246 |
|
240 | 247 |
|
241 | 248 | class YahooCollectorUS(YahooCollector, ABC):
|
@@ -450,10 +457,12 @@ def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
450 | 457 | _max_date = df.index.max()
|
451 | 458 | df = df.reindex(self._calendar_list).loc[:_max_date].reset_index()
|
452 | 459 | df = df[df[self._date_field_name] > _last_date]
|
| 460 | + if df.empty: |
| 461 | + return pd.DataFrame() |
453 | 462 | _si = df["close"].first_valid_index()
|
454 | 463 | if _si > df.index[0]:
|
455 | 464 | logger.warning(
|
456 |
| - f"{df.iloc[0][self._symbol_field_name]} missing data: {df.loc[:_si][self._date_field_name]}" |
| 465 | + f"{df.loc[_si][self._symbol_field_name]} missing data: {df.loc[:_si-1][self._date_field_name].to_list()}" |
457 | 466 | )
|
458 | 467 | # normalize
|
459 | 468 | df = self.normalize_yahoo(
|
@@ -661,7 +670,7 @@ def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
661 | 670 |
|
662 | 671 | def symbol_to_yahoo(self, symbol):
|
663 | 672 | if "." not in symbol:
|
664 |
| - _exchange = symbol[:2] |
| 673 | + _exchange = symbol[:2].lower() |
665 | 674 | _exchange = "ss" if _exchange == "sh" else _exchange
|
666 | 675 | symbol = symbol[2:] + "." + _exchange
|
667 | 676 | return symbol
|
@@ -864,7 +873,7 @@ def normalize_data_1d_extend(
|
864 | 873 | yc.normalize()
|
865 | 874 |
|
866 | 875 | def normalize_data_1min_cn_offline(
|
867 |
| - self, qlib_data_1d_dir, date_field_name: str = "date", symbol_field_name: str = "symbol" |
| 876 | + self, qlib_data_1d_dir: str, date_field_name: str = "date", symbol_field_name: str = "symbol" |
868 | 877 | ):
|
869 | 878 | """Normalised to 1min using local 1d data
|
870 | 879 |
|
@@ -942,6 +951,72 @@ def download_today_data(
|
942 | 951 | limit_nums,
|
943 | 952 | )
|
944 | 953 |
|
| 954 | + def update_data_to_bin(self, qlib_data_1d_dir: str, trading_date: str = None, end_date: str = None): |
| 955 | + """update yahoo data to bin |
| 956 | +
|
| 957 | + Parameters |
| 958 | + ---------- |
| 959 | + qlib_data_1d_dir: str |
| 960 | + the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data |
| 961 | +
|
| 962 | + trading_date: str |
| 963 | + trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")`` |
| 964 | + end_date: str |
| 965 | + end datetime, default ``pd.Timestamp(trading_date) + pd.Timedelta(days=1)``; half-open interval (end is excluded) |
| 966 | +
|
| 967 | + Notes |
| 968 | + ----- |
| 969 | + If the data in qlib_data_1d_dir is incomplete, np.nan will be populated to trading_date for the previous trading day |
| 970 | +
|
| 971 | + Examples |
| 972 | + -------- |
| 973 | + $ python collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date> |
| 974 | + # update 1d data and dump it to bin |
| 975 | + """ |
| 976 | + |
| 977 | + if self.interval.lower() != "1d": |
| 978 | + logger.warning(f"currently supports 1d data updates: --interval 1d") |
| 979 | + |
| 980 | + # start/end date |
| 981 | + if trading_date is None: |
| 982 | + trading_date = datetime.datetime.now().strftime("%Y-%m-%d") |
| 983 | + logger.warning(f"trading_date is None, use the current date: {trading_date}") |
| 984 | + |
| 985 | + if end_date is None: |
| 986 | + end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") |
| 987 | + |
| 988 | + # download qlib 1d data |
| 989 | + qlib_data_1d_dir = Path(qlib_data_1d_dir).expanduser().resolve() |
| 990 | + if not exists_qlib_data(qlib_data_1d_dir): |
| 991 | + GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region) |
| 992 | + |
| 993 | + # download data from yahoo |
| 994 | + self.download_data(delay=1, start=trading_date, end=end_date, check_data_length=1) |
| 995 | + |
| 996 | + # normalize data |
| 997 | + self.normalize_data_1d_extend(str(qlib_data_1d_dir)) |
| 998 | + |
| 999 | + # dump bin |
| 1000 | + _dump = DumpDataUpdate( |
| 1001 | + csv_path=self.normalize_dir, |
| 1002 | + qlib_dir=qlib_data_1d_dir, |
| 1003 | + exclude_fields="symbol,date", |
| 1004 | + max_workers=self.max_workers, |
| 1005 | + ) |
| 1006 | + _dump.dump() |
| 1007 | + |
| 1008 | + # parse index |
| 1009 | + _region = self.region.lower() |
| 1010 | + if _region not in ["cn", "us"]: |
| 1011 | + logger.warning(f"Unsupported region: region={_region}, component downloads will be ignored") |
| 1012 | + return |
| 1013 | + index_list = ["CSI100", "CSI300"] if _region == "cn" else ["SP500", "NASDAQ100", "DJIA", "SP400"] |
| 1014 | + get_instruments = getattr( |
| 1015 | + importlib.import_module(f"data_collector.{_region}_index.collector"), "get_instruments" |
| 1016 | + ) |
| 1017 | + for _index in index_list: |
| 1018 | + get_instruments(str(qlib_data_1d_dir), _index) |
| 1019 | + |
945 | 1020 |
|
946 | 1021 | if __name__ == "__main__":
|
947 | 1022 | fire.Fire(Run)
|
0 commit comments