-
Notifications
You must be signed in to change notification settings - Fork 4.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Add extension point for dataset-loader associations #2246
base: master
Are you sure you want to change the base?
Changes from 4 commits
7d41d4c
274a762
450b7c1
65f582f
cdbf1eb
aac863e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
from zipline.pipeline import USEquityPricingLoader | ||
from zipline.pipeline.data import ( | ||
Column, | ||
DataSet, | ||
BoundColumn, | ||
USEquityPricing, | ||
) | ||
from zipline.pipeline.dispatcher import ( | ||
PipelineDispatcher, | ||
clear_all_associations, | ||
) | ||
from zipline.pipeline.loaders.base import PipelineLoader | ||
from zipline.pipeline.sentinels import NotSpecified | ||
from zipline.testing import ZiplineTestCase | ||
from zipline.testing.fixtures import WithAdjustmentReader | ||
from zipline.testing.predicates import ( | ||
assert_raises_str, | ||
assert_equal, | ||
) | ||
from zipline.utils.numpy_utils import float64_dtype | ||
|
||
|
||
class FakeDataSet(DataSet):
    """Minimal DataSet with a single float64 column, used as a
    registration target in the dispatcher tests."""
    test_col = Column(float64_dtype)
|
||
|
||
class FakeColumn(BoundColumn):
    """BoundColumn subclass so the tests can construct column instances
    directly (BoundColumn is normally only obtained via DataSet access)."""
    pass
|
||
|
||
class FakePipelineLoader(PipelineLoader):
    """PipelineLoader stub whose load_adjusted_array is a no-op; the tests
    only care about registration identity, never about loaded data."""

    def load_adjusted_array(self, columns, dates, assets, mask):
        pass
|
||
|
||
class UnrelatedType(object):
    """Neither a BoundColumn nor a DataSet; used to exercise the
    TypeError path in PipelineDispatcher.register."""
    pass
|
||
|
||
class PipelineDispatcherTestCase(WithAdjustmentReader, ZiplineTestCase):
    """Tests for PipelineDispatcher's column/dataset -> loader associations."""

    @classmethod
    def init_class_fixtures(cls):
        super(PipelineDispatcherTestCase, cls).init_class_fixtures()
        cls.default_pipeline_loader = USEquityPricingLoader(
            cls.bcolz_equity_daily_bar_reader,
            cls.adjustment_reader,
        )
        # Undo any registrations made on the global dispatcher when this
        # test class is torn down, so state does not leak between classes.
        cls.add_class_callback(clear_all_associations)

    def test_load_not_registered(self):
        # Looking up a column that was never registered should raise
        # LookupError, while registered columns remain visible.
        fake_col_instance = FakeColumn(
            float64_dtype,
            NotSpecified,
            FakeDataSet,
            'test',
            None,
            {},
        )
        fake_pl_instance = FakePipelineLoader()
        pipeline_dispatcher = PipelineDispatcher(
            column_loaders={fake_col_instance: fake_pl_instance}
        )

        expected_dict = {fake_col_instance: fake_pl_instance}
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

        msg = "No pipeline loader registered for %s" % USEquityPricing.close
        with assert_raises_str(LookupError, msg):
            pipeline_dispatcher(USEquityPricing.close)

    def test_register_unrelated_type(self):
        # Registering something that is neither a BoundColumn nor a DataSet
        # is a TypeError.
        pipeline_dispatcher = PipelineDispatcher()
        fake_pl_instance = FakePipelineLoader()

        msg = "Data provided is neither a BoundColumn nor a DataSet"
        with assert_raises_str(TypeError, msg):
            pipeline_dispatcher.register(UnrelatedType, fake_pl_instance)

    def test_passive_registration(self):
        pipeline_dispatcher = PipelineDispatcher()
        assert_equal(pipeline_dispatcher.column_loaders, {})

        # imitate user registering a custom pipeline loader first
        custom_loader = FakePipelineLoader()
        pipeline_dispatcher.register(USEquityPricing.close, custom_loader)
        expected_dict = {USEquityPricing.close: custom_loader}
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

        # now check that trying to register something else won't change it
        pipeline_dispatcher.register(
            USEquityPricing.close, self.default_pipeline_loader
        )
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

    def test_normal_ops(self):
        # Constructor registration (both per-column and per-dataset) and
        # post-construction register() should all be reflected in
        # column_loaders and resolved by __call__.
        fake_loader_instance = FakePipelineLoader()
        fake_col_instance = FakeColumn(
            float64_dtype,
            NotSpecified,
            FakeDataSet,
            'test',
            None,
            {},
        )
        pipeline_dispatcher = PipelineDispatcher(
            column_loaders={
                fake_col_instance: fake_loader_instance
            },
            dataset_loaders={
                FakeDataSet: fake_loader_instance
            }
        )
        expected_dict = {
            fake_col_instance: fake_loader_instance,
            FakeDataSet.test_col: fake_loader_instance,
        }
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

        pipeline_dispatcher.register(
            USEquityPricing.close, fake_loader_instance
        )
        expected_dict = {
            fake_col_instance: fake_loader_instance,
            FakeDataSet.test_col: fake_loader_instance,
            USEquityPricing.close: fake_loader_instance,
        }
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

        assert_equal(
            pipeline_dispatcher(fake_col_instance), fake_loader_instance
        )
        assert_equal(
            pipeline_dispatcher(FakeDataSet.test_col), fake_loader_instance
        )
        assert_equal(
            pipeline_dispatcher(USEquityPricing.close), fake_loader_instance
        )
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from zipline.pipeline.data import BoundColumn, DataSet | ||
from zipline.pipeline.loaders.base import PipelineLoader | ||
from zipline.utils.compat import mappingproxy | ||
|
||
|
||
class PipelineDispatcher(object):
    """Helper class for building a dispatching function for a PipelineLoader.

    Parameters
    ----------
    column_loaders : dict[BoundColumn -> PipelineLoader]
        Map from columns to pipeline loader for those columns.
    dataset_loaders : dict[DataSet -> PipelineLoader]
        Map from datasets to pipeline loader for those datasets.

    NOTE(review): consider accepting a single ``loaders`` mapping keyed by
    either BoundColumn or DataSet rather than segregating the two — confirm
    with reviewers before merging.
    """
    def __init__(self, column_loaders=None, dataset_loaders=None):
        self._column_loaders = column_loaders if column_loaders \
            is not None else {}
        # Expose a read-only view so callers cannot mutate the registry
        # without going through register().
        self.column_loaders = mappingproxy(self._column_loaders)
        # Expand dataset-level registrations into per-column entries.
        # NOTE(review): calling methods on self inside __init__ is fragile
        # under subclassing; consider inlining the expansion.
        if dataset_loaders is not None:
            for dataset, pl in dataset_loaders.items():
                self.register(dataset, pl)

    def __call__(self, column):
        """Return the loader registered for ``column``.

        Raises
        ------
        LookupError
            If no loader has been registered for the column.
        """
        # EAFP: one dict lookup instead of a membership test + subscript.
        try:
            return self._column_loaders[column]
        except KeyError:
            raise LookupError("No pipeline loader registered for %s" % column)

    def register(self, data, pl):
        """Register a given PipelineLoader to a column or columns of a dataset

        Parameters
        ----------
        data : BoundColumn or DataSet
            The column or dataset for which to register the PipelineLoader
        pl : PipelineLoader
            The PipelineLoader to register for the column or dataset columns
        """
        # NOTE(review): assert is stripped under `python -O`; if this is a
        # user-facing API, raise TypeError with a message instead.
        assert isinstance(pl, PipelineLoader)

        # make it so that in either case nothing will happen if the column is
        # already registered, allowing users to register their own loaders
        # early on in extensions
        # NOTE(review): silently ignoring a conflicting re-registration may
        # mask bugs; an error (or last-write-wins) may be preferable.
        if isinstance(data, BoundColumn):
            if data not in self._column_loaders:
                self._column_loaders[data] = pl
        elif issubclass(data, DataSet):
            for c in data.columns:
                if c not in self._column_loaders:
                    self._column_loaders[c] = pl
        else:
            raise TypeError("Data provided is neither a BoundColumn "
                            "nor a DataSet")

    def clear(self):
        """Unregisters all dataset-loader associations"""
        self._column_loaders.clear()
|
||
|
||
# Module-level registry used by zipline's run machinery: extensions register
# their loaders here before the default USEquityPricing loader is installed.
# NOTE(review): a mutable global registry should be avoided unless it is
# absolutely necessary — revisit whether this can be instance-scoped.
global_pipeline_dispatcher = PipelineDispatcher()
register_pipeline_loader = global_pipeline_dispatcher.register
clear_all_associations = global_pipeline_dispatcher.clear
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
import warnings | ||
|
||
import click | ||
|
||
try: | ||
from pygments import highlight | ||
from pygments.lexers import PythonLexer | ||
|
@@ -21,6 +22,10 @@ | |
from zipline.data.data_portal import DataPortal | ||
from zipline.finance import metrics | ||
from zipline.finance.trading import TradingEnvironment | ||
from zipline.pipeline import ( | ||
register_pipeline_loader, | ||
global_pipeline_dispatcher, | ||
) | ||
from zipline.pipeline.data import USEquityPricing | ||
from zipline.pipeline.loaders import USEquityPricingLoader | ||
from zipline.utils.factory import create_simulation_parameters | ||
|
@@ -166,12 +171,13 @@ def _run(handle_data, | |
bundle_data.adjustment_reader, | ||
) | ||
|
||
# we register our default loader last, after any loaders from users | ||
# have been registered via extensions | ||
register_pipeline_loader(USEquityPricing, pipeline_loader) | ||
|
||
def choose_loader(column): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need to create a new function here anymore. |
||
if column in USEquityPricing.columns: | ||
return pipeline_loader | ||
raise ValueError( | ||
"No PipelineLoader registered for column %s." % column | ||
) | ||
return global_pipeline_dispatcher(column) | ||
|
||
else: | ||
env = TradingEnvironment(environ=environ) | ||
choose_loader = None | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't looked carefully at these tests yet, but in general, it's not expected that anyone should ever construct BoundColumn instances explicitly. The usual way to get a bound column is to do: