Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add extension point for dataset-loader associations #2246

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions tests/pipeline/test_dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from zipline.pipeline.data import (
Column,
DataSet,
BoundColumn,
USEquityPricing,
)
from zipline.pipeline.dispatcher import PipelineDispatcher
from zipline.pipeline.loaders.base import PipelineLoader
from zipline.pipeline.sentinels import NotSpecified
from zipline.testing import ZiplineTestCase
from zipline.testing.predicates import (
assert_raises_str,
assert_equal,
)
from zipline.utils.numpy_utils import float64_dtype


class FakeDataSet(DataSet):
    """Test-only dataset with one column, ``test_col``.

    The column is declared with ``float64_dtype``.
    """

    test_col = Column(float64_dtype)


class FakeColumn(BoundColumn):
    """Test-only ``BoundColumn`` subclass that adds no behavior."""


class FakePipelineLoader(PipelineLoader):
    """Test-only loader whose ``load_adjusted_array`` does nothing."""

    def load_adjusted_array(self, columns, dates, assets, mask):
        """Ignore all arguments and return ``None``."""
        return None


class UnrelatedType(object):
    """Plain class that is neither a ``BoundColumn`` nor a ``DataSet``.

    Used as an invalid key when constructing a ``PipelineDispatcher``.
    """


class PipelineDispatcherTestCase(ZiplineTestCase):
    """Tests for building a ``PipelineDispatcher`` and looking up loaders."""

    @staticmethod
    def _new_fake_column():
        """Return a fresh ``FakeColumn`` attached to ``FakeDataSet``."""
        return FakeColumn(
            float64_dtype,
            NotSpecified,
            FakeDataSet,
            'test',
            None,
            {},
        )

    def test_load_not_registered(self):
        """Dispatching on a column that was never registered raises."""
        column = self._new_fake_column()
        loader = FakePipelineLoader()
        dispatcher = PipelineDispatcher({column: loader})

        # Only the explicitly registered column should be present.
        assert_equal(dispatcher._column_loaders, {column: loader})

        unregistered = USEquityPricing.close
        expected_msg = "No pipeline loader registered for %s" % unregistered
        with assert_raises_str(LookupError, expected_msg):
            dispatcher(unregistered)

    def test_register_unrelated_type(self):
        """A key that is neither a column nor a dataset is rejected."""
        loader = FakePipelineLoader()
        expected_msg = (
            "%s is neither a BoundColumn nor a DataSet" % UnrelatedType
        )

        with assert_raises_str(TypeError, expected_msg):
            PipelineDispatcher({UnrelatedType: loader})

    def test_normal_ops(self):
        """Column keys and dataset keys both resolve to their loader."""
        loader = FakePipelineLoader()
        column = self._new_fake_column()
        dispatcher = PipelineDispatcher({
            column: loader,
            FakeDataSet: loader,
        })

        # The dataset key is expanded into one entry per dataset column.
        assert_equal(
            dispatcher._column_loaders,
            {column: loader, FakeDataSet.test_col: loader},
        )
        for registered in (column, FakeDataSet.test_col):
            assert_equal(dispatcher(registered), loader)
2 changes: 2 additions & 0 deletions zipline/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .graph import ExecutionPlan, TermGraph
from .pipeline import Pipeline
from .loaders import USEquityPricingLoader
from .dispatcher import PipelineDispatcher


def engine_from_files(daily_bar_path,
Expand Down Expand Up @@ -60,4 +61,5 @@ def engine_from_files(daily_bar_path,
'SimplePipelineEngine',
'Term',
'TermGraph',
'PipelineDispatcher',
)
28 changes: 28 additions & 0 deletions zipline/pipeline/dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from zipline.pipeline.data import BoundColumn, DataSet


class PipelineDispatcher(object):
    """Helper class for building a dispatching function for a PipelineLoader.

    Instances are callable: calling one with a column returns the loader
    registered for that column.

    Parameters
    ----------
    loaders : dict[BoundColumn or DataSet -> PipelineLoader]
        Map from columns or datasets to pipeline loader for those objects.
        A dataset key registers the loader for every column in
        ``dataset.columns``. A loader registered directly for a column
        takes precedence over one registered through that column's dataset,
        regardless of the iteration order of ``loaders``.

    Raises
    ------
    TypeError
        If a key of ``loaders`` is neither a ``BoundColumn`` instance nor a
        ``DataSet`` subclass.
    """
    def __init__(self, loaders):
        dataset_wide = {}
        column_specific = {}
        for data, pl in loaders.items():
            if isinstance(data, BoundColumn):
                column_specific[data] = pl
            # ``issubclass`` itself raises TypeError for non-class arguments,
            # so check for a class first to keep our own error message.
            elif isinstance(data, type) and issubclass(data, DataSet):
                for c in data.columns:
                    dataset_wide[c] = pl
            else:
                # Wrap in a tuple so that a tuple-valued key cannot be
                # mistaken for multiple format arguments.
                raise TypeError(
                    "%s is neither a BoundColumn nor a DataSet" % (data,)
                )

        # Apply column-level registrations last so they override any
        # dataset-wide registration for the same column.
        dataset_wide.update(column_specific)
        self._column_loaders = dataset_wide

    def __call__(self, column):
        """Return the loader registered for ``column``.

        Raises
        ------
        LookupError
            If no loader was registered for ``column``.
        """
        if column not in self._column_loaders:
            raise LookupError(
                "No pipeline loader registered for %s" % (column,)
            )
        return self._column_loaders[column]
30 changes: 17 additions & 13 deletions zipline/utils/run_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
import sys
import warnings

try:
from pygments import highlight
from pygments.lexers import PythonLexer
Expand All @@ -18,6 +17,7 @@
from zipline.data.loader import load_market_data
from zipline.data.data_portal import DataPortal
from zipline.finance import metrics
from zipline.pipeline import PipelineDispatcher
from zipline.finance.trading import SimulationParameters
from zipline.pipeline.data import USEquityPricing
from zipline.pipeline.loaders import USEquityPricingLoader
Expand Down Expand Up @@ -72,6 +72,7 @@ def _run(handle_data,
metrics_set,
local_namespace,
environ,
pipeline_dispatcher,
blotter,
benchmark_returns):
"""Run a backtest for the given algorithm.
Expand Down Expand Up @@ -155,16 +156,14 @@ def _run(handle_data,
adjustment_reader=bundle_data.adjustment_reader,
)

pipeline_loader = USEquityPricingLoader(
bundle_data.equity_daily_bar_reader,
bundle_data.adjustment_reader,
)

def choose_loader(column):
if column in USEquityPricing.columns:
return pipeline_loader
raise ValueError(
"No PipelineLoader registered for column %s." % column
if pipeline_dispatcher is None:
# create the default dispatcher
pipeline_loader = USEquityPricingLoader(
bundle_data.equity_daily_bar_reader,
bundle_data.adjustment_reader,
)
pipeline_dispatcher = PipelineDispatcher(
{USEquityPricing: pipeline_loader}
)

if isinstance(metrics_set, six.string_types):
Expand All @@ -181,8 +180,8 @@ def choose_loader(column):

perf = TradingAlgorithm(
namespace=namespace,
get_pipeline_loader=pipeline_dispatcher,
data_portal=data,
get_pipeline_loader=choose_loader,
trading_calendar=trading_calendar,
sim_params=SimulationParameters(
start_session=start,
Expand Down Expand Up @@ -225,7 +224,7 @@ def load_extensions(default, extensions, strict, environ, reload=False):
----------
default : bool
Load the default extension (~/.zipline/extension.py)?
extension : iterable[str]
extensions : iterable[str]
The paths to the extensions to load. If the path ends in ``.py`` it is
treated as a script and executed. If it does not end in ``.py`` it is
treated as a module to be imported.
Expand Down Expand Up @@ -285,6 +284,7 @@ def run_algorithm(start,
extensions=(),
strict_extensions=True,
environ=os.environ,
pipeline_dispatcher=None,
blotter='default'):
"""
Run a trading algorithm.
Expand Down Expand Up @@ -338,6 +338,9 @@ def run_algorithm(start,
environ : mapping[str -> str], optional
The os environment to use. Many extensions use this to get parameters.
This defaults to ``os.environ``.
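pipeline_dispatcher : PipelineDispatcher, optional
The pipeline dispatcher to use, which should contain any column-to-
loader associations necessary to run the trading algorithm. If not
provided, a default dispatcher that maps ``USEquityPricing`` to a
``USEquityPricingLoader`` is created.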
blotter : str or zipline.finance.blotter.Blotter, optional
Blotter to use with this algorithm. If passed as a string, we look for
a blotter construction function registered with
Expand Down Expand Up @@ -376,6 +379,7 @@ def run_algorithm(start,
metrics_set=metrics_set,
local_namespace=False,
environ=environ,
pipeline_dispatcher=pipeline_dispatcher,
blotter=blotter,
benchmark_returns=benchmark_returns,
)