ENH: Add extension point for dataset-loader associations #2246

Open · wants to merge 6 commits into master · showing changes from 4 commits
140 changes: 140 additions & 0 deletions tests/pipeline/test_dispatcher.py
@@ -0,0 +1,140 @@
from zipline.pipeline import USEquityPricingLoader
from zipline.pipeline.data import (
    Column,
    DataSet,
    BoundColumn,
    USEquityPricing,
)
from zipline.pipeline.dispatcher import (
    PipelineDispatcher,
    clear_all_associations,
)
from zipline.pipeline.loaders.base import PipelineLoader
from zipline.pipeline.sentinels import NotSpecified
from zipline.testing import ZiplineTestCase
from zipline.testing.fixtures import WithAdjustmentReader
from zipline.testing.predicates import (
    assert_raises_str,
    assert_equal,
)
from zipline.utils.numpy_utils import float64_dtype


class FakeDataSet(DataSet):
    test_col = Column(float64_dtype)


class FakeColumn(BoundColumn):
    pass


class FakePipelineLoader(PipelineLoader):

    def load_adjusted_array(self, columns, dates, assets, mask):
        pass


class UnrelatedType(object):
    pass


class PipelineDispatcherTestCase(WithAdjustmentReader, ZiplineTestCase):

    @classmethod
    def init_class_fixtures(cls):
        super(PipelineDispatcherTestCase, cls).init_class_fixtures()
        cls.default_pipeline_loader = USEquityPricingLoader(
            cls.bcolz_equity_daily_bar_reader,
            cls.adjustment_reader,
        )

        cls.add_class_callback(clear_all_associations)

    def test_load_not_registered(self):
        fake_col_instance = FakeColumn(
Contributor:

I haven't looked carefully at these tests yet, but in general, it's not expected that anyone should ever construct BoundColumn instances explicitly. The usual way to get a bound column is to do:

class SomeDataSet(DataSet):
    column = Column(...)

SomeDataSet.column  # This evaluates to a BoundColumn because Column implements the descriptor protocol.

            float64_dtype,
            NotSpecified,
            FakeDataSet,
            'test',
            None,
            {},
        )
        fake_pl_instance = FakePipelineLoader()
        pipeline_dispatcher = PipelineDispatcher(
            column_loaders={fake_col_instance: fake_pl_instance}
        )

        expected_dict = {fake_col_instance: fake_pl_instance}
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

        msg = "No pipeline loader registered for %s" % USEquityPricing.close
        with assert_raises_str(LookupError, msg):
            pipeline_dispatcher(USEquityPricing.close)

    def test_register_unrelated_type(self):
        pipeline_dispatcher = PipelineDispatcher()
        fake_pl_instance = FakePipelineLoader()

        msg = "Data provided is neither a BoundColumn nor a DataSet"
        with assert_raises_str(TypeError, msg):
            pipeline_dispatcher.register(UnrelatedType, fake_pl_instance)

    def test_passive_registration(self):
        pipeline_dispatcher = PipelineDispatcher()
        assert_equal(pipeline_dispatcher.column_loaders, {})

        # imitate a user registering a custom pipeline loader first
        custom_loader = FakePipelineLoader()
        pipeline_dispatcher.register(USEquityPricing.close, custom_loader)
        expected_dict = {USEquityPricing.close: custom_loader}
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

        # now check that trying to register something else won't change it
        pipeline_dispatcher.register(
            USEquityPricing.close, self.default_pipeline_loader
        )
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

    def test_normal_ops(self):
        fake_loader_instance = FakePipelineLoader()
        fake_col_instance = FakeColumn(
            float64_dtype,
            NotSpecified,
            FakeDataSet,
            'test',
            None,
            {},
        )
        pipeline_dispatcher = PipelineDispatcher(
            column_loaders={
                fake_col_instance: fake_loader_instance
            },
            dataset_loaders={
                FakeDataSet: fake_loader_instance
            }
        )
        expected_dict = {
            fake_col_instance: fake_loader_instance,
            FakeDataSet.test_col: fake_loader_instance,
        }
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

        pipeline_dispatcher.register(
            USEquityPricing.close, fake_loader_instance
        )
        expected_dict = {
            fake_col_instance: fake_loader_instance,
            FakeDataSet.test_col: fake_loader_instance,
            USEquityPricing.close: fake_loader_instance,
        }
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

        assert_equal(
            pipeline_dispatcher(fake_col_instance), fake_loader_instance
        )
        assert_equal(
            pipeline_dispatcher(FakeDataSet.test_col), fake_loader_instance
        )
        assert_equal(
            pipeline_dispatcher(USEquityPricing.close), fake_loader_instance
        )
3 changes: 3 additions & 0 deletions tests/test_examples.py
@@ -21,6 +21,7 @@

from zipline import examples
from zipline.data.bundles import register, unregister
from zipline.pipeline.dispatcher import clear_all_associations
from zipline.testing import test_resource_path
from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase
from zipline.testing.predicates import assert_equal
@@ -63,6 +64,8 @@ def init_class_fixtures(cls):
            )
        )

        cls.add_class_callback(clear_all_associations)

    @parameterized.expand(sorted(examples.EXAMPLE_MODULES))
    def test_example(self, example_name):
        actual_perf = examples.run_example(
3 changes: 3 additions & 0 deletions zipline/pipeline/__init__.py
@@ -9,6 +9,7 @@
from .graph import ExecutionPlan, TermGraph
from .pipeline import Pipeline
from .loaders import USEquityPricingLoader
from .dispatcher import global_pipeline_dispatcher, register_pipeline_loader


def engine_from_files(daily_bar_path,
@@ -60,4 +61,6 @@
    'SimplePipelineEngine',
    'Term',
    'TermGraph',
    'global_pipeline_dispatcher',
    'register_pipeline_loader',
)
63 changes: 63 additions & 0 deletions zipline/pipeline/dispatcher.py
@@ -0,0 +1,63 @@
from zipline.pipeline.data import BoundColumn, DataSet
from zipline.pipeline.loaders.base import PipelineLoader
from zipline.utils.compat import mappingproxy


class PipelineDispatcher(object):
    """Helper class for building a dispatching function for a PipelineLoader.

    Parameters
    ----------
    column_loaders : dict[BoundColumn -> PipelineLoader]
        Map from columns to the pipeline loader for those columns.
    dataset_loaders : dict[DataSet -> PipelineLoader]
        Map from datasets to the pipeline loader for those datasets.
    """
    def __init__(self, column_loaders=None, dataset_loaders=None):
Contributor:

Is there a reason to segregate these into column_loaders and dataset_loaders? We could just accept loaders with a dictionary from either BoundColumn or DataSet -> loader.
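
A minimal sketch of that unified signature, under the assumption that the fan-out to columns stays the same (hypothetical, not what this PR implements):

def __init__(self, loaders=None):
    # Accept a single mapping whose keys may be BoundColumns or DataSets.
    self._column_loaders = {}
    if loaders is not None:
        for data, loader in loaders.items():
            if isinstance(data, BoundColumn):
                self._column_loaders[data] = loader
            else:
                # Treat anything else as a DataSet: register each column.
                for column in data.columns:
                    self._column_loaders[column] = loader
    self.column_loaders = mappingproxy(self._column_loaders)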

        self._column_loaders = column_loaders if column_loaders \
            is not None else {}
        self.column_loaders = mappingproxy(self._column_loaders)
        if dataset_loaders is not None:
            for dataset, pl in dataset_loaders.items():
                self.register(dataset, pl)
Contributor:

I tend to try to avoid calling methods on self in constructors. Generally when editing a method, most people assume that the class is fully-constructed by the time the method is called (for example, they assume that all attributes will be set). Calling a method in a constructor can violate that assumption.
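
One way to follow that advice, sketched assuming the expansion logic is otherwise unchanged: hoist the fan-out into a module-level helper (the name _columns_of is hypothetical) so __init__ never dispatches through a method on a partially constructed instance.

def _columns_of(data):
    # Module-level helper (hypothetical name): a BoundColumn maps to
    # itself, a DataSet fans out to all of its columns. No self involved.
    if isinstance(data, BoundColumn):
        return [data]
    return list(data.columns)


class PipelineDispatcher(object):
    def __init__(self, column_loaders=None, dataset_loaders=None):
        self._column_loaders = dict(column_loaders or {})
        for dataset, loader in (dataset_loaders or {}).items():
            for column in _columns_of(dataset):
                # First registration wins, matching the PR's intent.
                self._column_loaders.setdefault(column, loader)
        self.column_loaders = mappingproxy(self._column_loaders)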


    def __call__(self, column):
        if column in self._column_loaders:
Contributor:

It's generally more idiomatic in Python to write this as:

try:
    return self._column_loaders[column]
except KeyError:
    raise LookupError(...)

cf. https://blogs.msdn.microsoft.com/pythonengineering/2016/06/29/idiomatic-python-eafp-versus-lbyl/, for example.

            return self._column_loaders[column]
        else:
            raise LookupError("No pipeline loader registered for %s" % column)

    def register(self, data, pl):
        """Register a given PipelineLoader for a column or for all columns
        of a dataset.

        Parameters
        ----------
        data : BoundColumn or DataSet
            The column or dataset for which to register the PipelineLoader.
        pl : PipelineLoader
            The PipelineLoader to register for the column or dataset columns.
        """
        assert isinstance(pl, PipelineLoader)
Contributor:

Two notes on this:

  1. I don't think we need to be doing type checking on these inputs. A pipeline loader can be any object that implements load_adjusted_array with the correct signature. Forcing all the loaders to be subclasses of a base class isn't necessary. If a user wants to validate that their class implements the correct methods, that might make sense, but I had actually forgotten that the PipelineLoader base class existed, and we definitely have loaders downstream that don't subclass that base class.
  2. It's almost never a good idea to use assert for validating user input. assert statements get disabled if you run Python with -O. They're meant to be used for sanity checks in the internal logic of your program, not for verifying that a value received from a user matches some contract. Copy/pasting from a reply that I wrote elsewhere earlier this summer:

When Should You Use Assertions?

Generally speaking, you should use assertions for things that indicate that
there's a bug in your program or library that's severe enough that you'd prefer
for the program to crash immediately with a useful error rather than
continuing. In particular, you should only use assertions for things that you
have reason to believe to be true.

That generally means that you should use assertions for things like
preconditions and postconditions of functions that are internal to the library,
and that you shouldn't use assertions for things like validating user input
(whether or not the user passes you valid inputs isn't under your control, so
it can't be indicative of a bug in your part of the program).

A common scenario where assertions are useful is to have a user-facing API
layer that validates that input is well formed that then passes the known-good
input into internal library functions. You might include assertions in the
internal library functions to ensure that we don't accidentally call the
internal function on un-validated data, e.g.:

def user_facing_function(possibly_bad_data):
    good_data = validate(possibly_bad_data)
    return _internal_function(good_data)

def _internal_function(good_data):
    assert is_good(good_data), "good_data isn't good. Did you forget to validate?"
    # do stuff
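
Applied to this register method, a sketch of explicit validation in place of the assert, duck-typed on load_adjusted_array per note 1 (the message wording is illustrative):

def register(self, data, pl):
    # Validate user input with a real exception, not an assert, so the
    # check survives running Python with -O.
    if not callable(getattr(pl, 'load_adjusted_array', None)):
        raise TypeError(
            "Expected an object implementing load_adjusted_array, "
            "got: %s" % type(pl).__name__
        )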


        # make it so that in either case nothing will happen if the column is
        # already registered, allowing users to register their own loaders
        # early on in extensions
        if isinstance(data, BoundColumn):
            if data not in self._column_loaders:
Contributor:

It seems dangerous to me to silently ignore conflicts if a loader gets registered for a dataset multiple times. I would expect this to be an error, and if not, I would have expected the later registration to win.

Independently of the above, dict.setdefault is a built-in way to write a key to a dictionary only if it's not present.
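
For reference, the setdefault form of the branch below collapses to a one-liner (it keeps any loader that is already registered, which is this PR's behavior):

# Writes pl only if data is not already a key; no-op otherwise.
self._column_loaders.setdefault(data, pl)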

                self._column_loaders[data] = pl
        elif issubclass(data, DataSet):
            for c in data.columns:
                if c not in self._column_loaders:
                    self._column_loaders[c] = pl
        else:
            raise TypeError("Data provided is neither a BoundColumn "
                            "nor a DataSet")

    def clear(self):
        """Unregisters all dataset-loader associations"""
        self._column_loaders.clear()


global_pipeline_dispatcher = PipelineDispatcher()
Contributor:

Do we need a global registry for this? In general, I think we should avoid the pattern of having a mutable global registry of stuff unless it's absolutely necessary.
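
A sketch of the alternative the comment points toward: build a dispatcher locally and hand it to whatever needs a loader-lookup callable, rather than mutating module-level state (hypothetical wiring, assuming the SimplePipelineEngine(get_loader, calendar, asset_finder) signature):

dispatcher = PipelineDispatcher(
    dataset_loaders={USEquityPricing: pipeline_loader},
)
# PipelineDispatcher is callable, so it can serve directly as get_loader.
engine = SimplePipelineEngine(dispatcher, calendar, asset_finder)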

register_pipeline_loader = global_pipeline_dispatcher.register
clear_all_associations = global_pipeline_dispatcher.clear
16 changes: 11 additions & 5 deletions zipline/utils/run_algo.py
@@ -5,6 +5,7 @@
import warnings

import click

try:
    from pygments import highlight
    from pygments.lexers import PythonLexer
@@ -21,6 +22,10 @@
from zipline.data.data_portal import DataPortal
from zipline.finance import metrics
from zipline.finance.trading import TradingEnvironment
from zipline.pipeline import (
    register_pipeline_loader,
    global_pipeline_dispatcher,
)
from zipline.pipeline.data import USEquityPricing
from zipline.pipeline.loaders import USEquityPricingLoader
from zipline.utils.factory import create_simulation_parameters
@@ -166,12 +171,13 @@ def _run(handle_data,
            bundle_data.adjustment_reader,
        )

        # we register our default loader last, after any loaders from users
        # have been registered via extensions
        register_pipeline_loader(USEquityPricing, pipeline_loader)

        def choose_loader(column):
Contributor:

We don't need to create a new function here anymore. global_pipeline_dispatcher is already a callable with the right signature.
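
In other words, the def below could collapse to a plain assignment:

choose_loader = global_pipeline_dispatcher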

            if column in USEquityPricing.columns:
                return pipeline_loader
            raise ValueError(
                "No PipelineLoader registered for column %s." % column
            )
            return global_pipeline_dispatcher(column)

    else:
        env = TradingEnvironment(environ=environ)
        choose_loader = None