Skip to content

Commit

Permalink
Chore: Make release 1.1.4
Browse files Browse the repository at this point in the history
  • Loading branch information
martinroberson committed Aug 27, 2024
1 parent a8bdfea commit 4880552
Show file tree
Hide file tree
Showing 14 changed files with 401 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ def legs_decoder(data: Any):


def legs_encoder(data: Iterable[Instrument]):
return [i.as_dict() for i in data]
return [i.to_dict() for i in data]
24 changes: 22 additions & 2 deletions gs_quant/backtests/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class AddTradeAction(Action):
:param trade_duration: an instrument attribute eg. 'expiration_date' or a date or a tenor or timedelta
if left as None the
trade will be added for all future dates
can also specify 'next schedule' in order to exit at the next periodic trigger date
:param name: optional additional name to the priceable name
:param transaction_cost: optional a cash amount paid for each transaction, paid on both enter and exit
"""
Expand Down Expand Up @@ -139,11 +140,12 @@ def dated_priceables(self):
return self._dated_priceables


AddTradeActionInfo = namedtuple('AddTradeActionInfo', 'scaling')
AddTradeActionInfo = namedtuple('AddTradeActionInfo', ['scaling', 'next_schedule'])
EnterPositionQuantityScaledActionInfo = namedtuple('EnterPositionQuantityScaledActionInfo', 'not_applicable')
HedgeActionInfo = namedtuple('HedgeActionInfo', 'not_applicable')
HedgeActionInfo = namedtuple('HedgeActionInfo', 'next_schedule')
ExitTradeActionInfo = namedtuple('ExitTradeActionInfo', 'not_applicable')
RebalanceActionInfo = namedtuple('RebalanceActionInfo', 'not_applicable')
AddScaledTradeActionInfo = namedtuple('AddScaledActionInfo', 'next_schedule')


@dataclass_json
Expand All @@ -158,6 +160,7 @@ class AddScaledTradeAction(Action):
:param trade_duration: an instrument attribute eg. 'expiration_date' or a date or a tenor or timedelta
if left as None the
trade will be added for all future dates
can also specify 'next schedule' in order to exit at the next periodic trigger date
:param name: optional additional name to the priceable name
:param scaling_type: the type of scaling we are doing
:param scaling_risk: if the scaling type is a measure then this is the definition of the measure
Expand Down Expand Up @@ -263,6 +266,23 @@ def __post_init__(self):
@dataclass_json
@dataclass
class HedgeAction(Action):

"""
create an action which adds a hedge trade when triggered. This trade will be scaled to hedge the risk
specified. The trades are resolved on the trigger date (state) and
last until the trade_duration if specified or for all future dates if not.
:param risk: a risk measure which should be hedged
:param priceables: a priceable or a list of pricables these should have sensitivity to the risk.
:param trade_duration: an instrument attribute eg. 'expiration_date' or a date or a tenor or timedelta
if left as None the
trade will be added for all future dates
can also specify 'next schedule' in order to exit at the next periodic trigger date
:param name: optional additional name to the priceable name
:param transaction_cost: optional a transaction cost model, paid on both enter and exit
:param risk_transformation: optional a Transformer which will be applied to the raw risk numbers before hedging
:param holiday_calendar: optional an iterable list of holiday dates
"""

risk: RiskMeasure = field(default=None, metadata=config(decoder=decode_risk_measure,
encoder=encode_risk_measure))
priceables: Optional[Priceable] = field(default=None, metadata=config(decoder=decode_named_instrument,
Expand Down
6 changes: 5 additions & 1 deletion gs_quant/backtests/backtest_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def make_list(thing):
final_date_cache = {}


def get_final_date(inst, create_date, duration, holiday_calendar=None):
def get_final_date(inst, create_date, duration, holiday_calendar=None, trigger_info=None):
global final_date_cache
cache_key = (inst, create_date, duration, holiday_calendar)
if cache_key in final_date_cache:
Expand All @@ -58,6 +58,10 @@ def get_final_date(inst, create_date, duration, holiday_calendar=None):
if hasattr(inst, str(duration)):
final_date_cache[cache_key] = getattr(inst, str(duration))
return getattr(inst, str(duration))
if str(duration).lower() == 'next schedule':
if hasattr(trigger_info, 'next_schedule'):
return trigger_info.next_schedule or dt.date.max
raise RuntimeError('Next schedule not supported by action')

final_date_cache[cache_key] = RelativeDate(duration, create_date).apply_rule(holiday_calendar=holiday_calendar)
return final_date_cache[cache_key]
Expand Down
56 changes: 36 additions & 20 deletions gs_quant/backtests/generic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@

from gs_quant import risk
from gs_quant.backtests.action_handler import ActionHandlerBaseFactory, ActionHandler
from gs_quant.backtests.actions import Action, AddTradeAction, HedgeAction, EnterPositionQuantityScaledAction, \
AddTradeActionInfo, HedgeActionInfo, ExitTradeAction, ExitTradeActionInfo, EnterPositionQuantityScaledActionInfo, \
RebalanceAction, RebalanceActionInfo, ExitAllPositionsAction, AddScaledTradeAction, ScalingActionType
from gs_quant.backtests.actions import (Action, AddTradeAction, HedgeAction, EnterPositionQuantityScaledAction,
AddTradeActionInfo, HedgeActionInfo, ExitTradeAction, ExitTradeActionInfo,
EnterPositionQuantityScaledActionInfo, RebalanceAction, RebalanceActionInfo,
ExitAllPositionsAction, AddScaledTradeAction, ScalingActionType,
AddScaledTradeActionInfo)
from gs_quant.backtests.backtest_engine import BacktestBaseEngine
from gs_quant.backtests.backtest_objects import BackTest, ScalingPortfolio, CashPayment, Hedge
from gs_quant.backtests.backtest_utils import make_list, CalcType, get_final_date
from gs_quant.common import AssetClass
from gs_quant.common import ParameterisedRiskMeasure, RiskMeasure
from gs_quant.context_base import nullcontext
from gs_quant.datetime.relative_date import RelativeDateSchedule
Expand All @@ -39,7 +42,6 @@
from gs_quant.risk import Price
from gs_quant.risk.results import PortfolioRiskResult
from gs_quant.target.backtests import BacktestTradingQuantityType
from gs_quant.common import AssetClass
from gs_quant.target.measures import ResolvedInstrumentValues
from gs_quant.tracing import Tracer

Expand Down Expand Up @@ -74,7 +76,7 @@ def _raise_order(self,
final_orders = {}
for d, p in orders.items():
new_port = Portfolio([t.clone(name=f'{t.name}_{d}') for t in p[0].result()])
final_orders[d] = new_port.scale(None if p[1] is None else p[1].scaling, in_place=False)
final_orders[d] = (new_port.scale(None if p[1] is None else p[1].scaling, in_place=False), p[1])

return final_orders

Expand All @@ -86,12 +88,13 @@ def apply_action(self,
orders = self._raise_order(state, trigger_info)

# record entry and unwind cashflows
for create_date, portfolio in orders.items():
for create_date, (portfolio, info) in orders.items():
for inst in portfolio.all_instruments:
backtest.cash_payments[create_date].append(CashPayment(inst, effective_date=create_date, direction=-1))
backtest.transaction_costs[create_date] -= self.action.transaction_cost.get_cost(create_date, backtest,
trigger_info, inst)
final_date = get_final_date(inst, create_date, self.action.trade_duration, self.action.holiday_calendar)
final_date = get_final_date(inst, create_date, self.action.trade_duration, self.action.holiday_calendar,
info)
backtest.cash_payments[final_date].append(CashPayment(inst, effective_date=final_date))
backtest.transaction_costs[final_date] -= self.action.transaction_cost.get_cost(final_date,
backtest,
Expand Down Expand Up @@ -169,9 +172,8 @@ def _scale_order(self, orders, daily_risk, price_measure):
raise RuntimeError(f'Scaling Type {self.action.scaling_type} not supported by engine')

def _raise_order(self,
state: Union[date, Iterable[date]],
state_list: Iterable[date],
price_measure: RiskMeasure):
state_list = make_list(state)
orders = {}
order_valuations = (ResolvedInstrumentValues,)
if self.action.scaling_type == ScalingActionType.risk_measure:
Expand Down Expand Up @@ -201,18 +203,24 @@ def _raise_order(self,
def apply_action(self,
state: Union[date, Iterable[date]],
backtest: BackTest,
trigger_info: Optional[Union[EnterPositionQuantityScaledActionInfo,
Iterable[EnterPositionQuantityScaledActionInfo]]] = None):
trigger_info: Optional[Union[AddScaledTradeActionInfo,
Iterable[AddScaledTradeActionInfo]]] = None):

orders = self._raise_order(state, backtest.price_measure)
state_list = make_list(state)
if trigger_info is None or isinstance(trigger_info, AddScaledTradeActionInfo):
trigger_info = [trigger_info for _ in range(len(state_list))]
orders = self._raise_order(state_list, backtest.price_measure)
trigger_infos = dict(zip_longest(state_list, trigger_info))

# record entry and unwind cashflows
for create_date, portfolio in orders.items():
info = trigger_infos[create_date]
for inst in portfolio.all_instruments:
backtest.cash_payments[create_date].append(CashPayment(inst, effective_date=create_date, direction=-1))
backtest.transaction_costs[create_date] -= self.action.transaction_cost.get_cost(create_date, backtest,
trigger_info, inst)
final_date = get_final_date(inst, create_date, self.action.trade_duration, self.action.holiday_calendar)
final_date = get_final_date(inst, create_date, self.action.trade_duration, self.action.holiday_calendar,
info)
backtest.cash_payments[final_date].append(CashPayment(inst, effective_date=final_date))
backtest.transaction_costs[final_date] -= self.action.transaction_cost.get_cost(final_date,
backtest,
Expand Down Expand Up @@ -372,19 +380,25 @@ def apply_action(self,
state: Union[date, Iterable[date]],
backtest: BackTest,
trigger_info: Optional[Union[HedgeActionInfo, Iterable[HedgeActionInfo]]] = None):
with HistoricalPricingContext(dates=make_list(state), csa_term=self.action.csa_term):
state_list = make_list(state)
if trigger_info is None or isinstance(trigger_info, HedgeActionInfo):
trigger_info = [trigger_info for _ in range(len(state_list))]
trigger_infos = dict(zip_longest(state_list, trigger_info))

with HistoricalPricingContext(dates=state_list, csa_term=self.action.csa_term):
backtest.calc_calls += 1
backtest.calculations += len(make_list(state))
backtest.calculations += len(state_list)
f = Portfolio(self.action.priceable).resolve(in_place=False)

for create_date, portfolio in f.result().items():
info = trigger_infos[create_date]
hedge_trade = portfolio.priceables[0]
hedge_trade.name = f'{hedge_trade.name}_{create_date.strftime("%Y-%m-%d")}'
if isinstance(hedge_trade, Portfolio):
for instrument in hedge_trade.all_instruments:
instrument.name = f'{hedge_trade.name}_{instrument.name}'
final_date = get_final_date(hedge_trade, create_date, self.action.trade_duration,
self.action.holiday_calendar)
self.action.holiday_calendar, info)
active_dates = [s for s in backtest.states if create_date <= s < final_date]

if len(active_dates):
Expand Down Expand Up @@ -711,10 +725,12 @@ def __run(self, strategy, start, end, frequency, states, risks, initial_value, r
self._price_semi_det_triggers(backtest, risks)

logger.info('Scaling semi-determ triggers and actions and calculating path dependent triggers and actions')
for d in strategy_pricing_dates:
with self._trace('Process date') as scope:
with self._trace('Process dates') as scope:
if scope:
scope.span.set_tag('dates.length', len(strategy_pricing_dates))
for d in strategy_pricing_dates:
if scope:
scope.span.set_tag('date', str(d))
scope.span.log_kv({'date': str(d)})
self._process_triggers_and_actions_for_date(d, strategy, backtest, risks)

with self._trace('Calc New Trades'):
Expand Down Expand Up @@ -795,7 +811,7 @@ def _price_semi_det_triggers(self, backtest, risks):
port = p.trade if isinstance(p.trade, Portfolio) else Portfolio([p.trade])
p.results = port.calc(tuple(risks))

def _process_triggers_and_actions_for_date(self, d, strategy, backtest, risks):
def _process_triggers_and_actions_for_date(self, d, strategy, backtest: BackTest, risks):
logger.debug(f'{d}: Processing triggers and actions')
# path dependent
for trigger in strategy.triggers:
Expand Down
22 changes: 13 additions & 9 deletions gs_quant/backtests/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,27 @@
under the License.
"""

from dataclasses import dataclass
from dataclasses_json import dataclass_json, config
from dataclasses import dataclass, field
from typing import Tuple, Optional, Union, Iterable

from gs_quant.backtests.triggers import *
from gs_quant.backtests.generic_engine import GenericEngine
from gs_quant.backtests.predefined_asset_engine import PredefinedAssetEngine
from gs_quant.backtests.equity_vol_engine import EquityVolEngine
from dataclasses_json import dataclass_json, config

from gs_quant.backtests.backtest_utils import make_list
from gs_quant.backtests.triggers import Trigger
from gs_quant.base import Priceable
from gs_quant.json_convertors import decode_named_instrument, encode_named_instrument, dc_decode

backtest_engines = [GenericEngine(), PredefinedAssetEngine(), EquityVolEngine()]

def _backtest_engines():
from gs_quant.backtests.equity_vol_engine import EquityVolEngine
from gs_quant.backtests.generic_engine import GenericEngine
from gs_quant.backtests.predefined_asset_engine import PredefinedAssetEngine
return [GenericEngine(), PredefinedAssetEngine(), EquityVolEngine()]


@dataclass_json
@dataclass
class Strategy(object):
class Strategy:
"""
A strategy object on which one may run a backtest
"""
Expand All @@ -54,4 +58,4 @@ def get_risks(self):
return risk_list

def get_available_engines(self):
return [engine for engine in backtest_engines if engine.supports_strategy(self)]
return [engine for engine in _backtest_engines() if engine.supports_strategy(self)]
9 changes: 8 additions & 1 deletion gs_quant/backtests/triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,14 @@ def get_trigger_times(self) -> [dt.date]:
def has_triggered(self, state: dt.date, backtest: BackTest = None) -> TriggerInfo:
if not self.trigger_dates:
self.get_trigger_times()
return TriggerInfo(state in self.trigger_dates)
if state in self.trigger_dates:
next_state = None
if self.trigger_dates.index(state) != len(self.trigger_dates) - 1:
next_state = self.trigger_dates[self.trigger_dates.index(state) + 1]
return TriggerInfo(True, {AddTradeAction: AddTradeActionInfo(scaling=None, next_schedule=next_state),
AddScaledTradeAction: AddScaledTradeActionInfo(next_schedule=next_state),
HedgeAction: HedgeActionInfo(next_schedule=next_state)})
return TriggerInfo(False)


@dataclass_json
Expand Down
Loading

0 comments on commit 4880552

Please sign in to comment.