Skip to content

Commit

Permalink
🐞 fix: lifecycle函数未能触发问题
Browse files Browse the repository at this point in the history
  • Loading branch information
aaron-yang-biz committed Nov 21, 2023
1 parent 0c5e251 commit 6404a56
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 16 deletions.
4 changes: 3 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# History

#
# 2.0.0-alpha78
* backtest中捕获异常时,如果是TradeError类型,打印该对象自带的stack
## 2.0.0-alpha77
* strategy增加lifecycle
* 保留最后一个回测周期仅供交易使用,不调用`predict`
Expand Down
9 changes: 5 additions & 4 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,10 @@ sma = SMAStrategy(
start=datetime.date(2023, 2, 3),
end=datetime.date(2023, 4, 28),
frame_type=FrameType.DAY,
warmup_period = 20
)

await sma.backtest(prefetch_stocks=["600000.XSHG", min_bars=20])
await sma.backtest(prefetch_stocks=["600000.XSHG"])
```
在回测时,必须要指定`is_backtest=True``start`, `end`参数。
### 3.2. 回测报告
Expand Down Expand Up @@ -346,7 +347,7 @@ await sma.plot_metrics(indicator)

在回测中,可以使用主周期的数据预取,以加快回测速度。工作原理如下:

如果策略在调用`backtest`时传入了`prefetch_stocks``min_bars`参数,则`backtest`将会在回测之前,预取从[start - min_bars * frame_type, end]间的portfolio行情数据,并在每次调用`predict`方法时,通过`barss`参数,将[start - min_bars * frame_type, start + i * frame_type]间的数据传给`predict`方法。传入的数据已进行前复权。
如果策略指定了`warmup_period`,并在调用`backtest`时传入了`prefetch_stocks`参数,则`backtest`将会在回测之前,预取从[start - warmup_period * frame_type, end]间的portfolio行情数据,并在每次调用`predict`方法时,通过`barss`参数,将[start - warmup_period * frame_type, start + i * frame_type]间的数据传给`predict`方法。传入的数据已进行前复权。

如果在回测过程中,需要偷看未来数据,可以使用peek方法。

Expand All @@ -370,8 +371,8 @@ sec = "600000.XSHG"
start = datetime.date(2022, 1, 4)
end = datetime.date(2023, 1,1)

sma = SMAStrategy(sec, url=cfg.backtest.url, is_backtest=True, start=start, end=end, frame_type=FrameType.DAY)
await sma.backtest(portfolio=[sec], min_bars=10, stop_on_error=False)
sma = SMAStrategy(sec, url=cfg.backtest.url, is_backtest=True, start=start, end=end, frame_type=FrameType.DAY, warmup_period=10)
await sma.backtest(portfolio=[sec], stop_on_error=False)
await sma.plot_metrics(sma.indicators)

```
Expand Down
22 changes: 14 additions & 8 deletions omicron/strategy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import jqdatasdk as jq
import numpy as np
import pandas as pd
import traderclient
from coretypes import BarsArray, Frame, FrameType
from coretypes.errors.trade import TradeError
from deprecation import deprecated
from traderclient import TraderClient

Expand Down Expand Up @@ -157,32 +157,36 @@ async def backtest(self, stop_on_error: bool = True, **kwargs):
tf.get_frames(self.bs.start, end_, self._frame_type) # type: ignore
):
barss = self._next()
day_barss = barss if self._frame_type == FrameType.DAY else None
frame_ = converter(frame)

prev_frame = tf.shift(frame_, -1, self._frame_type)
next_frame = tf.shift(frame_, 1, self._frame_type)

# new trading day start
if (not intra_day and prev_frame > frame_) or (
if (not intra_day and prev_frame < frame_) or (
intra_day and prev_frame.date() < frame_.date()
):
await self.before_trade(frame_)
await self.before_trade(frame_, day_barss)

logger.debug("%sth iteration", i, date=frame_)
try:
await self.predict(
frame_, self._frame_type, i, barss=barss, **kwargs # type: ignore
)
except Exception as e:
logger.exception(e)
if isinstance(e, TradeError):
logger.warning("call stack is:\n%s", e.stack)
else:
logger.exception(e)
if stop_on_error:
raise e

# trading day ends
if (not intra_day and next_frame > frame_) or (
intra_day and next_frame.date() > frame_.date()
):
await self.after_trade(frame_)
await self.after_trade(frame_, day_barss)

self.broker.stop_backtest()

Expand Down Expand Up @@ -319,19 +323,21 @@ async def before_start(self):
else:
logger.info("BEFORE_START: %s", self.name)

async def before_trade(self, date: datetime.date):
async def before_trade(self, date: datetime.date, barss: Optional[Dict[str, BarsArray]]=None):
"""每日开盘前的准备工作
Args:
date: 日期。在回测中为回测当日日期,在实盘中为系统日期
barss: 如果主周期为日线,且支持预取,则会将预取的barss传入
"""
logger.debug("BEFORE_TRADE: %s", self.name, date=date)

async def after_trade(self, date: Frame):
async def after_trade(self, date: Frame, barss: Optional[Dict[str, BarsArray]]=None):
"""每日收盘后的收尾工作
Args:
date: 日期。在回测中为回测当日日期,在实盘中为系统日期
barss: 如果主周期为日线,且支持预取,则会将预取的barss传入
"""
logger.debug("AFTER_TRADE: %s", self.name, date=date)

Expand Down Expand Up @@ -361,7 +367,7 @@ async def predict(
frame: 当前时间帧
frame_type: 处理的数据主周期
i: 当前时间离回测起始的单位数
barss: 如果调用`backtest`时传入了`portfolio`及`min_bars`参数,则`backtest`将会在回测之前,预取从[start - min_bars * frame_type, end]间的portfolio行情数据,并在每次调用`predict`方法时,通过`barss`参数,将[start - min_bars * frame_type, start + i * frame_type]间的数据传给`predict`方法。传入的数据已进行前复权。
barss: 如果调用`backtest`时传入了`portfolio`及参数,则`backtest`将会在回测之前,预取从[start - warmup_period * frame_type, end]间的portfolio行情数据,并在每次调用`predict`方法时,通过`barss`参数,将[start - warmup_period * frame_type, start + i * frame_type]间的数据传给`predict`方法。传入的数据已进行前复权。
Keyword Args: 在`backtest`方法中的传入的kwargs参数将被透传到此方法中。
"""
Expand Down
2 changes: 1 addition & 1 deletion omicron/strategy/sma.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async def predict(
self, frame: Frame, frame_type: FrameType, i: int, barss, **kwargs
):
if barss is None:
raise ValueError("please specify `prefetch_stocks` and `min_bars`")
raise ValueError("please specify `prefetch_stocks`")

bars: Union[BarsArray, None] = barss.get(self._sec)
if bars is None:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ name = "zillionare-omicron"
packages = [
{include = "omicron"}
]
version = "2.0.0a77"
version = "2.0.0a78"
description = "Core Library for Zillionare"
authors = ["jieyu <[email protected]>"]
license = "MIT"
Expand Down
3 changes: 2 additions & 1 deletion tests/strategy/test_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ async def test_sma_strategy(self, mc1, mc2, mc3):
start=datetime.date(2023, 2, 3),
end=datetime.date(2023, 4, 28),
frame_type=FrameType.DAY,
warmup_period = 20
)

# setup the mock
Expand All @@ -113,4 +114,4 @@ async def test_sma_strategy(self, mc1, mc2, mc3):

# no exception is ok
with mock.patch.object(omicron.models.stock.Stock, "get_bars", self.get_bars):
await sma.backtest(stop_on_error=True, portfolio=[code], min_bars=10)
await sma.backtest(stop_on_error=True, portfolio=[code])

0 comments on commit 6404a56

Please sign in to comment.