Skip to content

Commit 46714ad

Browse files
committed
modify the YahooNormalize1min factor calculation
1 parent 99fb496 commit 46714ad

File tree

2 files changed

+118
-126
lines changed

2 files changed

+118
-126
lines changed

qlib/contrib/report/analysis_model/analysis_model_performance.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -84,30 +84,29 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
8484
qqplot_data = _plt_fig.gca().lines
8585
fig = go.Figure()
8686

87-
fig.add_trace({
88-
'type': 'scatter',
89-
'x': qqplot_data[0].get_xdata(),
90-
# 'x': [0, 1],
91-
'y': qqplot_data[0].get_ydata(),
92-
# 'y': [1, 2],
93-
'mode': 'markers',
94-
'marker': {
95-
'color': '#19d3f3'
96-
}
97-
})
98-
99-
fig.add_trace({
100-
'type': 'scatter',
101-
'x': qqplot_data[1].get_xdata(),
102-
# 'x': [0, 1],
103-
'y': qqplot_data[1].get_ydata(),
104-
# 'y': [1, 2],
105-
'mode': 'lines',
106-
'line': {
107-
'color': '#636efa'
87+
fig.add_trace(
88+
{
89+
"type": "scatter",
90+
"x": qqplot_data[0].get_xdata(),
91+
# 'x': [0, 1],
92+
"y": qqplot_data[0].get_ydata(),
93+
# 'y': [1, 2],
94+
"mode": "markers",
95+
"marker": {"color": "#19d3f3"},
10896
}
97+
)
10998

110-
})
99+
fig.add_trace(
100+
{
101+
"type": "scatter",
102+
"x": qqplot_data[1].get_xdata(),
103+
# 'x': [0, 1],
104+
"y": qqplot_data[1].get_ydata(),
105+
# 'y': [1, 2],
106+
"mode": "lines",
107+
"line": {"color": "#636efa"},
108+
}
109+
)
111110
del qqplot_data
112111
return fig
113112

scripts/data_collector/yahoo/collector.py

Lines changed: 97 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -478,8 +478,8 @@ class YahooNormalize1min(YahooNormalize, ABC):
478478
PM_RANGE = None # type: tuple # eg: ("13:00:00", "14:59:00")
479479

480480
# Whether the trading day of 1min data is consistent with 1d
481-
CONSISTENT_1d = False
482-
CALC_PAUSED_NUM = False
481+
CONSISTENT_1d = True
482+
CALC_PAUSED_NUM = True
483483

484484
@property
485485
def calendar_list_1d(self):
@@ -500,7 +500,7 @@ def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
500500
Returns
501501
------
502502
data_1d: pd.DataFrame
503-
set(data_1d.columns) - set([self._date_field_name, self._symbol_field_name, "paused", "volume", "factor"]) == {}
503+
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
504504
505505
"""
506506
data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
@@ -516,14 +516,15 @@ def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
516516
if df.empty:
517517
return df
518518
df = df.copy()
519+
df = df.sort_values(self._date_field_name)
519520
symbol = df.iloc[0][self._symbol_field_name]
520521
# get 1d data from yahoo
521522
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
522523
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
523524
data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end)
524525
data_1d = data_1d.copy()
525526
if data_1d is None or data_1d.empty:
526-
df["factor"] = 1
527+
df["factor"] = 1 / df.loc[df["close"].first_valid_index()]
527528
# TODO: np.nan or 1 or 0
528529
df["paused"] = np.nan
529530
else:
@@ -534,9 +535,13 @@ def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
534535
data_1d = data_1d.set_index(self._date_field_name)
535536

536537
# add factor from 1d data
538+
# NOTE: yahoo 1d data info:
539+
# - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
540+
# - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.
541+
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
537542
df["date_tmp"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
538543
df.set_index("date_tmp", inplace=True)
539-
df.loc[:, "factor"] = data_1d["factor"]
544+
df.loc[:, "factor"] = data_1d["close"] / df["close"]
540545
df.loc[:, "paused"] = data_1d["paused"]
541546
df.reset_index("date_tmp", drop=True, inplace=True)
542547

@@ -619,6 +624,61 @@ def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
619624
raise NotImplementedError("rewrite _get_1d_calendar_list")
620625

621626

627+
class YahooNormalize1minOffline(YahooNormalize1min):
628+
"""Normalised to 1min using local 1d data"""
629+
630+
def __init__(
631+
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
632+
):
633+
"""
634+
635+
Parameters
636+
----------
637+
qlib_data_1d_dir: str, Path
638+
the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data
639+
date_field_name: str
640+
date field name, default is date
641+
symbol_field_name: str
642+
symbol field name, default is symbol
643+
"""
644+
self.qlib_data_1d_dir = qlib_data_1d_dir
645+
super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name)
646+
self._all_1d_data = self._get_all_1d_data()
647+
648+
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
649+
import qlib
650+
from qlib.data import D
651+
652+
qlib.init(provider_uri=self.qlib_data_1d_dir)
653+
return list(D.calendar(freq="day"))
654+
655+
def _get_all_1d_data(self):
656+
import qlib
657+
from qlib.data import D
658+
659+
qlib.init(provider_uri=self.qlib_data_1d_dir)
660+
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
661+
df.reset_index(inplace=True)
662+
df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
663+
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
664+
return df
665+
666+
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
667+
"""get 1d data
668+
669+
Returns
670+
------
671+
data_1d: pd.DataFrame
672+
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
673+
674+
"""
675+
return self._all_1d_data[
676+
(self._all_1d_data[self._symbol_field_name] == symbol.upper())
677+
& (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start))
678+
& (self._all_1d_data[self._date_field_name] < pd.Timestamp(end))
679+
]
680+
681+
622682
class YahooNormalizeUS:
623683
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
624684
# TODO: from MSN
@@ -629,8 +689,8 @@ class YahooNormalizeUS1d(YahooNormalizeUS, YahooNormalize1d):
629689
pass
630690

631691

632-
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
633-
CONSISTENT_1d = False
692+
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline):
693+
CALC_PAUSED_NUM = False
634694

635695
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
636696
# TODO: support 1min
@@ -657,84 +717,24 @@ class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend):
657717
pass
658718

659719

660-
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
720+
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline):
661721
AM_RANGE = ("09:30:00", "11:29:00")
662722
PM_RANGE = ("13:00:00", "14:59:00")
663723

664-
CONSISTENT_1d = True
665-
CALC_PAUSED_NUM = True
666-
667724
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
668725
return self.generate_1min_from_daily(self.calendar_list_1d)
669726

670727
def symbol_to_yahoo(self, symbol):
671728
if "." not in symbol:
672-
_exchange = symbol[:2].lower()
673-
_exchange = "ss" if _exchange == "sh" else _exchange
729+
_exchange = symbol[:2]
730+
_exchange = ("ss" if _exchange.islower() else "SS") if _exchange.lower() == "sh" else _exchange
674731
symbol = symbol[2:] + "." + _exchange
675732
return symbol
676733

677734
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
678735
return get_calendar_list("ALL")
679736

680737

681-
class YahooNormalizeCN1minOffline(YahooNormalizeCN1min):
682-
"""Normalised to 1min using local 1d data
683-
1d data usually from: Normalised to 1min using local 1d data
684-
"""
685-
686-
def __init__(
687-
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
688-
):
689-
"""
690-
691-
Parameters
692-
----------
693-
qlib_data_1d_dir: str, Path
694-
the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data
695-
date_field_name: str
696-
date field name, default is date
697-
symbol_field_name: str
698-
symbol field name, default is symbol
699-
"""
700-
self.qlib_data_1d_dir = qlib_data_1d_dir
701-
super(YahooNormalizeCN1minOffline, self).__init__(date_field_name, symbol_field_name)
702-
self._all_1d_data = self._get_all_1d_data()
703-
704-
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
705-
import qlib
706-
from qlib.data import D
707-
708-
qlib.init(provider_uri=self.qlib_data_1d_dir)
709-
return list(D.calendar(freq="day"))
710-
711-
def _get_all_1d_data(self):
712-
import qlib
713-
from qlib.data import D
714-
715-
qlib.init(provider_uri=self.qlib_data_1d_dir)
716-
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor"], freq="day")
717-
df.reset_index(inplace=True)
718-
df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
719-
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
720-
return df
721-
722-
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
723-
"""get 1d data
724-
725-
Returns
726-
------
727-
data_1d: pd.DataFrame
728-
set(data_1d.columns) - set([self._date_field_name, self._symbol_field_name, "paused", "volume", "factor"]) == {}
729-
730-
"""
731-
return self._all_1d_data[
732-
(self._all_1d_data[self._symbol_field_name] == symbol.upper())
733-
& (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start))
734-
& (self._all_1d_data[self._date_field_name] < pd.Timestamp(end))
735-
]
736-
737-
738738
class Run(BaseRun):
739739
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN):
740740
"""
@@ -811,7 +811,13 @@ def download_data(
811811
max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums
812812
)
813813

814-
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", end_date: str = None):
814+
def normalize_data(
815+
self,
816+
date_field_name: str = "date",
817+
symbol_field_name: str = "symbol",
818+
end_date: str = None,
819+
qlib_data_1d_dir: str = None,
820+
):
815821
"""normalize data
816822
817823
Parameters
@@ -822,12 +828,29 @@ def normalize_data(self, date_field_name: str = "date", symbol_field_name: str =
822828
symbol field name, default symbol
823829
end_date: str
824830
if not None, normalize the last date saved (including end_date); if None, it will ignore this parameter; by default None
831+
qlib_data_1d_dir: str
832+
if interval==1min, qlib_data_1d_dir cannot be None, normalize 1min needs to use 1d data;
833+
834+
qlib_data_1d can be obtained like this:
835+
$ python scripts/get_data.py qlilb_data --target_dir <qlib_data_1d_dir> --interval 1d
836+
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --trading_date 2021-06-01
837+
or:
838+
download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo
825839
826840
Examples
827841
---------
828842
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d
843+
$ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min
829844
"""
830-
super(Run, self).normalize_data(date_field_name, symbol_field_name, end_date=end_date)
845+
if self.interval.lower() == "1min":
846+
if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():
847+
# TODO: add reference url
848+
raise ValueError(
849+
"If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: "
850+
)
851+
super(Run, self).normalize_data(
852+
date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir
853+
)
831854

832855
def normalize_data_1d_extend(
833856
self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol"
@@ -873,36 +896,6 @@ def normalize_data_1d_extend(
873896
)
874897
yc.normalize()
875898

876-
def normalize_data_1min_cn_offline(
877-
self, qlib_data_1d_dir: str, date_field_name: str = "date", symbol_field_name: str = "symbol"
878-
):
879-
"""Normalised to 1min using local 1d data
880-
881-
Parameters
882-
----------
883-
qlib_data_1d_dir: str
884-
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
885-
date_field_name: str
886-
date field name, default date
887-
symbol_field_name: str
888-
symbol field name, default symbol
889-
890-
Examples
891-
---------
892-
$ python collector.py normalize_data_1min_cn_offline --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min
893-
"""
894-
_class = getattr(self._cur_module, f"{self.normalize_class_name}Offline")
895-
yc = Normalize(
896-
source_dir=self.source_dir,
897-
target_dir=self.normalize_dir,
898-
normalize_class=_class,
899-
max_workers=self.max_workers,
900-
date_field_name=date_field_name,
901-
symbol_field_name=symbol_field_name,
902-
qlib_data_1d_dir=qlib_data_1d_dir,
903-
)
904-
yc.normalize()
905-
906899
def download_today_data(
907900
self,
908901
max_collector_count=2,
@@ -987,15 +980,15 @@ def update_data_to_bin(self, qlib_data_1d_dir: str, trading_date: str = None, en
987980
end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
988981

989982
# download qlib 1d data
990-
qlib_data_1d_dir = Path(qlib_data_1d_dir).expanduser().resolve()
983+
qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())
991984
if not exists_qlib_data(qlib_data_1d_dir):
992985
GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region)
993986

994987
# download data from yahoo
995988
self.download_data(delay=1, start=trading_date, end=end_date, check_data_length=1)
996989

997990
# normalize data
998-
self.normalize_data_1d_extend(str(qlib_data_1d_dir))
991+
self.normalize_data_1d_extend(qlib_data_1d_dir)
999992

1000993
# dump bin
1001994
_dump = DumpDataUpdate(

0 commit comments

Comments
 (0)