From 7d41d4c50829ac26428931bda4219a53e497e8be Mon Sep 17 00:00:00 2001 From: Jacob Nazarenko Date: Wed, 11 Jul 2018 11:36:00 -0400 Subject: [PATCH 1/5] ENH: Create PipelineDispatcher class for column/loader associations --- zipline/pipeline/__init__.py | 3 ++ zipline/pipeline/dispatcher.py | 58 ++++++++++++++++++++++++++++++++++ zipline/utils/run_algo.py | 16 +++++++--- 3 files changed, 72 insertions(+), 5 deletions(-) create mode 100644 zipline/pipeline/dispatcher.py diff --git a/zipline/pipeline/__init__.py b/zipline/pipeline/__init__.py index a169256bb9..95b7200fe3 100644 --- a/zipline/pipeline/__init__.py +++ b/zipline/pipeline/__init__.py @@ -9,6 +9,7 @@ from .graph import ExecutionPlan, TermGraph from .pipeline import Pipeline from .loaders import USEquityPricingLoader +from .dispatcher import global_pipeline_dispatcher, register_pipeline_loader def engine_from_files(daily_bar_path, @@ -60,4 +61,6 @@ def engine_from_files(daily_bar_path, 'SimplePipelineEngine', 'Term', 'TermGraph', + 'global_pipeline_dispatcher', + 'register_pipeline_loader', ) diff --git a/zipline/pipeline/dispatcher.py b/zipline/pipeline/dispatcher.py new file mode 100644 index 0000000000..172355d1e8 --- /dev/null +++ b/zipline/pipeline/dispatcher.py @@ -0,0 +1,58 @@ +from zipline.pipeline.data import BoundColumn, DataSet +from zipline.pipeline.loaders.base import PipelineLoader +from zipline.utils.compat import mappingproxy + + +class PipelineDispatcher(object): + """Helper class for building a dispatching function for a PipelineLoader. + + Parameters + ---------- + column_loaders : dict[BoundColumn -> PipelineLoader] + Map from columns to pipeline loader for those columns. + dataset_loaders : dict[DataSet -> PipelineLoader] + Map from datasets to pipeline loader for those datasets. 
+ """ + def __init__(self, column_loaders=None, dataset_loaders=None): + self._column_loaders = column_loaders if column_loaders \ + is not None else {} + self.column_loaders = mappingproxy(self._column_loaders) + if dataset_loaders is not None: + for dataset, pl in dataset_loaders: + self.register(dataset, pl) + + def __call__(self, column): + if column in self._column_loaders: + return self._column_loaders[column] + else: + raise LookupError("No pipeline loader registered for %s", column) + + def register(self, data, pl): + """Register a given PipelineLoader to a column or columns of a dataset + + Parameters + ---------- + data : BoundColumn or DataSet + The column or dataset for which to register the PipelineLoader + pl : PipelineLoader + The PipelineLoader to register for the column or dataset columns + """ + assert isinstance(pl, PipelineLoader) + + # make it so that in either case nothing will happen if the column is + # already registered, allowing users to register their own loaders + # early on in extensions + if isinstance(data, BoundColumn): + if data not in self._column_loaders: + self._column_loaders[data] = pl + elif issubclass(data, DataSet): + for c in data.columns: + if c not in self._column_loaders: + self._column_loaders[c] = pl + else: + raise TypeError("Data provided is neither a BoundColumn " + "nor a DataSet") + + +global_pipeline_dispatcher = PipelineDispatcher() +register_pipeline_loader = global_pipeline_dispatcher.register diff --git a/zipline/utils/run_algo.py b/zipline/utils/run_algo.py index cebd018a9d..aff04a7235 100644 --- a/zipline/utils/run_algo.py +++ b/zipline/utils/run_algo.py @@ -5,6 +5,7 @@ import warnings import click + try: from pygments import highlight from pygments.lexers import PythonLexer @@ -21,6 +22,10 @@ from zipline.data.data_portal import DataPortal from zipline.finance import metrics from zipline.finance.trading import TradingEnvironment +from zipline.pipeline import ( + register_pipeline_loader, + global_pipeline_dispatcher, +) from zipline.pipeline.data import USEquityPricing from zipline.pipeline.loaders import USEquityPricingLoader from zipline.utils.factory import create_simulation_parameters @@ -166,12 +171,13 @@ def _run(handle_data, bundle_data.adjustment_reader, ) + # we register our default loader last, after any loaders from users + # have been registered via extensions + register_pipeline_loader(USEquityPricing, pipeline_loader) + def choose_loader(column): - if column in USEquityPricing.columns: - return pipeline_loader - raise ValueError( - "No PipelineLoader registered for column %s." 
% column - ) + return global_pipeline_dispatcher(column) + else: env = TradingEnvironment(environ=environ) choose_loader = None From 274a762df663a3a1554c589c02647e502ec9dc54 Mon Sep 17 00:00:00 2001 From: Jacob Nazarenko Date: Wed, 11 Jul 2018 16:26:23 -0400 Subject: [PATCH 2/5] TST: Add tests for PipelineDispatcher --- tests/pipeline/test_dispatcher.py | 138 ++++++++++++++++++++++++++++++ zipline/pipeline/dispatcher.py | 4 +- 2 files changed, 140 insertions(+), 2 deletions(-) create mode 100644 tests/pipeline/test_dispatcher.py diff --git a/tests/pipeline/test_dispatcher.py b/tests/pipeline/test_dispatcher.py new file mode 100644 index 0000000000..5b042c747c --- /dev/null +++ b/tests/pipeline/test_dispatcher.py @@ -0,0 +1,138 @@ +import os + +from zipline.data import bundles +from zipline.pipeline import USEquityPricingLoader +from zipline.pipeline.data import ( + Column, + DataSet, + BoundColumn, + USEquityPricing, +) +from zipline.pipeline.dispatcher import PipelineDispatcher +from zipline.pipeline.loaders.base import PipelineLoader +from zipline.pipeline.sentinels import NotSpecified +from zipline.testing import ZiplineTestCase +from zipline.testing.fixtures import WithAdjustmentReader +from zipline.testing.predicates import ( + assert_raises_str, + assert_equal, +) +from zipline.utils.numpy_utils import float64_dtype + + +class FakeDataSet(DataSet): + test_col = Column(float64_dtype) + + +class FakeColumn(BoundColumn): + pass + + +class FakePipelineLoader(PipelineLoader): + + def load_adjusted_array(self, columns, dates, assets, mask): + pass + + +class UnrelatedType(object): + pass + + +class PipelineDispatcherTestCase(WithAdjustmentReader, ZiplineTestCase): + + @classmethod + def init_class_fixtures(cls): + super(PipelineDispatcherTestCase, cls).init_class_fixtures() + cls.default_pipeline_loader = USEquityPricingLoader( + cls.bcolz_equity_daily_bar_reader, + cls.adjustment_reader, + ) + + def test_load_not_registered(self): + fake_col_instance = FakeColumn( + float64_dtype, + NotSpecified, + FakeDataSet, + 'test', + None, + {}, + ) + fake_pl_instance = FakePipelineLoader() + pipeline_dispatcher = PipelineDispatcher( + column_loaders={fake_col_instance: fake_pl_instance} + ) + + expected_dict = {fake_col_instance: fake_pl_instance} + assert_equal(pipeline_dispatcher.column_loaders, expected_dict) + + msg = "No pipeline loader registered for %s" % USEquityPricing.close + with assert_raises_str(LookupError, msg): + pipeline_dispatcher(USEquityPricing.close) + + def test_register_unrelated_type(self): + pipeline_dispatcher = PipelineDispatcher() + fake_pl_instance = FakePipelineLoader() + + msg = "Data provided is neither a BoundColumn nor a DataSet" + with assert_raises_str(TypeError, msg): + pipeline_dispatcher.register(UnrelatedType, fake_pl_instance) + + def test_passive_registration(self): + pipeline_dispatcher = PipelineDispatcher() + assert_equal(pipeline_dispatcher.column_loaders, {}) + + # imitate user registering a custom pipeline loader first + custom_loader = FakePipelineLoader() + pipeline_dispatcher.register(USEquityPricing.close, custom_loader) + expected_dict = {USEquityPricing.close: custom_loader} + assert_equal(pipeline_dispatcher.column_loaders, expected_dict) + + # now check that trying to register something else won't change it + pipeline_dispatcher.register( + USEquityPricing.close, self.default_pipeline_loader + ) + assert_equal(pipeline_dispatcher.column_loaders, expected_dict) + + def test_normal_ops(self): + fake_loader_instance = FakePipelineLoader() + 
fake_col_instance = FakeColumn( + float64_dtype, + NotSpecified, + FakeDataSet, + 'test', + None, + {}, + ) + pipeline_dispatcher = PipelineDispatcher( + column_loaders={ + fake_col_instance: fake_loader_instance + }, + dataset_loaders={ + FakeDataSet: fake_loader_instance + } + ) + expected_dict = { + fake_col_instance: fake_loader_instance, + FakeDataSet.test_col: fake_loader_instance, + } + assert_equal(pipeline_dispatcher.column_loaders, expected_dict) + + pipeline_dispatcher.register( + USEquityPricing.close, fake_loader_instance + ) + expected_dict = { + fake_col_instance: fake_loader_instance, + FakeDataSet.test_col: fake_loader_instance, + USEquityPricing.close: fake_loader_instance, + } + assert_equal(pipeline_dispatcher.column_loaders, expected_dict) + + assert_equal( + pipeline_dispatcher(fake_col_instance), fake_loader_instance + ) + assert_equal( + pipeline_dispatcher(FakeDataSet.test_col), fake_loader_instance + ) + assert_equal( + pipeline_dispatcher(USEquityPricing.close), fake_loader_instance + ) diff --git a/zipline/pipeline/dispatcher.py b/zipline/pipeline/dispatcher.py index 172355d1e8..17acd68abd 100644 --- a/zipline/pipeline/dispatcher.py +++ b/zipline/pipeline/dispatcher.py @@ -18,14 +18,14 @@ def __init__(self, column_loaders=None, dataset_loaders=None): is not None else {} self.column_loaders = mappingproxy(self._column_loaders) if dataset_loaders is not None: - for dataset, pl in dataset_loaders: + for dataset, pl in dataset_loaders.items(): self.register(dataset, pl) def __call__(self, column): if column in self._column_loaders: return self._column_loaders[column] else: - raise LookupError("No pipeline loader registered for %s", column) + raise LookupError("No pipeline loader registered for %s" % column) def register(self, data, pl): """Register a given PipelineLoader to a column or columns of a dataset From 450b7c1c65dde60a11195d366bd4b841753960a9 Mon Sep 17 00:00:00 2001 From: Jacob Nazarenko Date: Wed, 11 Jul 2018 16:27:24 -0400 Subject: [PATCH 3/5] STY: Fix flake8 --- tests/pipeline/test_dispatcher.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/pipeline/test_dispatcher.py b/tests/pipeline/test_dispatcher.py index 5b042c747c..90bf6b1cdc 100644 --- a/tests/pipeline/test_dispatcher.py +++ b/tests/pipeline/test_dispatcher.py @@ -1,6 +1,3 @@ -import os - -from zipline.data import bundles from zipline.pipeline import USEquityPricingLoader from zipline.pipeline.data import ( Column, From 65f582fa5cabe36c447612ff294f8411c6a26ac2 Mon Sep 17 00:00:00 2001 From: Jacob Nazarenko Date: Thu, 12 Jul 2018 10:23:57 -0400 Subject: [PATCH 4/5] BUG: Fix teardown bug with global dispatcher instance --- tests/pipeline/test_dispatcher.py | 7 ++++++- tests/test_examples.py | 3 +++ zipline/pipeline/dispatcher.py | 5 +++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/pipeline/test_dispatcher.py b/tests/pipeline/test_dispatcher.py index 90bf6b1cdc..519ac90206 100644 --- a/tests/pipeline/test_dispatcher.py +++ b/tests/pipeline/test_dispatcher.py @@ -5,7 +5,10 @@ BoundColumn, USEquityPricing, ) -from zipline.pipeline.dispatcher import PipelineDispatcher +from zipline.pipeline.dispatcher import ( + PipelineDispatcher, + clear_all_associations, +) from zipline.pipeline.loaders.base import PipelineLoader from zipline.pipeline.sentinels import NotSpecified from zipline.testing import ZiplineTestCase @@ -45,6 +48,8 @@ def init_class_fixtures(cls): cls.adjustment_reader, ) + cls.add_class_callback(clear_all_associations) + def 
test_load_not_registered(self): fake_col_instance = FakeColumn( float64_dtype, diff --git a/tests/test_examples.py b/tests/test_examples.py index b069888db3..43bb7c4970 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -21,6 +21,7 @@ from zipline import examples from zipline.data.bundles import register, unregister +from zipline.pipeline.dispatcher import clear_all_associations from zipline.testing import test_resource_path from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase from zipline.testing.predicates import assert_equal @@ -63,6 +64,8 @@ def init_class_fixtures(cls): ) ) + cls.add_class_callback(clear_all_associations) + @parameterized.expand(sorted(examples.EXAMPLE_MODULES)) def test_example(self, example_name): actual_perf = examples.run_example( diff --git a/zipline/pipeline/dispatcher.py b/zipline/pipeline/dispatcher.py index 17acd68abd..7465dbe44c 100644 --- a/zipline/pipeline/dispatcher.py +++ b/zipline/pipeline/dispatcher.py @@ -53,6 +53,11 @@ def register(self, data, pl): raise TypeError("Data provided is neither a BoundColumn " "nor a DataSet") + def clear(self): + """Unregisters all dataset-loader associations""" + self._column_loaders.clear() + global_pipeline_dispatcher = PipelineDispatcher() register_pipeline_loader = global_pipeline_dispatcher.register +clear_all_associations = global_pipeline_dispatcher.clear From cdbf1ebe1b5744a0e039110e99e1a4452e2caa91 Mon Sep 17 00:00:00 2001 From: Jacob Nazarenko Date: Tue, 17 Jul 2018 18:21:38 -0400 Subject: [PATCH 5/5] MAINT: Revise PipelineDispatcher and remove global state --- tests/pipeline/test_dispatcher.py | 71 ++++++------------------------- tests/test_examples.py | 3 -- zipline/pipeline/__init__.py | 5 +-- zipline/pipeline/dispatcher.py | 61 ++++++-------------------- zipline/utils/run_algo.py | 33 +++++++------- 5 files changed, 44 insertions(+), 129 deletions(-) diff --git a/tests/pipeline/test_dispatcher.py b/tests/pipeline/test_dispatcher.py index 519ac90206..6937dfd5a6 100644 --- a/tests/pipeline/test_dispatcher.py +++ b/tests/pipeline/test_dispatcher.py @@ -1,18 +1,13 @@ -from zipline.pipeline import USEquityPricingLoader from zipline.pipeline.data import ( Column, DataSet, BoundColumn, USEquityPricing, ) -from zipline.pipeline.dispatcher import ( - PipelineDispatcher, - clear_all_associations, -) +from zipline.pipeline.dispatcher import PipelineDispatcher from zipline.pipeline.loaders.base import PipelineLoader from zipline.pipeline.sentinels import NotSpecified from zipline.testing import ZiplineTestCase -from zipline.testing.fixtures import WithAdjustmentReader from zipline.testing.predicates import ( assert_raises_str, assert_equal, @@ -38,17 +33,7 @@ class UnrelatedType(object): pass -class PipelineDispatcherTestCase(WithAdjustmentReader, ZiplineTestCase): - - @classmethod - def init_class_fixtures(cls): - super(PipelineDispatcherTestCase, cls).init_class_fixtures() - cls.default_pipeline_loader = USEquityPricingLoader( - cls.bcolz_equity_daily_bar_reader, - cls.adjustment_reader, - ) - - cls.add_class_callback(clear_all_associations) +class PipelineDispatcherTestCase(ZiplineTestCase): def test_load_not_registered(self): fake_col_instance = FakeColumn( @@ -61,39 +46,24 @@ def test_load_not_registered(self): ) fake_pl_instance = FakePipelineLoader() pipeline_dispatcher = PipelineDispatcher( - column_loaders={fake_col_instance: fake_pl_instance} + {fake_col_instance: fake_pl_instance} ) expected_dict = {fake_col_instance: fake_pl_instance} - 
assert_equal(pipeline_dispatcher.column_loaders, expected_dict) + assert_equal(pipeline_dispatcher._column_loaders, expected_dict) msg = "No pipeline loader registered for %s" % USEquityPricing.close with assert_raises_str(LookupError, msg): pipeline_dispatcher(USEquityPricing.close) def test_register_unrelated_type(self): - pipeline_dispatcher = PipelineDispatcher() fake_pl_instance = FakePipelineLoader() - msg = "Data provided is neither a BoundColumn nor a DataSet" + msg = "%s is neither a BoundColumn nor a DataSet" % UnrelatedType with assert_raises_str(TypeError, msg): - pipeline_dispatcher.register(UnrelatedType, fake_pl_instance) - - def test_passive_registration(self): - pipeline_dispatcher = PipelineDispatcher() - assert_equal(pipeline_dispatcher.column_loaders, {}) - - # imitate user registering a custom pipeline loader first - custom_loader = FakePipelineLoader() - pipeline_dispatcher.register(USEquityPricing.close, custom_loader) - expected_dict = {USEquityPricing.close: custom_loader} - assert_equal(pipeline_dispatcher.column_loaders, expected_dict) - - # now check that trying to register something else won't change it - pipeline_dispatcher.register( - USEquityPricing.close, self.default_pipeline_loader - ) - assert_equal(pipeline_dispatcher.column_loaders, expected_dict) + PipelineDispatcher( + {UnrelatedType: fake_pl_instance} + ) def test_normal_ops(self): fake_loader_instance = FakePipelineLoader() @@ -105,36 +75,19 @@ def test_normal_ops(self): None, {}, ) - pipeline_dispatcher = PipelineDispatcher( - column_loaders={ - fake_col_instance: fake_loader_instance - }, - dataset_loaders={ + pipeline_dispatcher = PipelineDispatcher({ + fake_col_instance: fake_loader_instance, FakeDataSet: fake_loader_instance - } - ) - expected_dict = { - fake_col_instance: fake_loader_instance, - FakeDataSet.test_col: fake_loader_instance, - } - assert_equal(pipeline_dispatcher.column_loaders, expected_dict) + }) - pipeline_dispatcher.register( - USEquityPricing.close, fake_loader_instance - ) expected_dict = { fake_col_instance: fake_loader_instance, FakeDataSet.test_col: fake_loader_instance, - USEquityPricing.close: fake_loader_instance, } - assert_equal(pipeline_dispatcher.column_loaders, expected_dict) - + assert_equal(pipeline_dispatcher._column_loaders, expected_dict) assert_equal( pipeline_dispatcher(fake_col_instance), fake_loader_instance ) assert_equal( pipeline_dispatcher(FakeDataSet.test_col), fake_loader_instance ) - assert_equal( - pipeline_dispatcher(USEquityPricing.close), fake_loader_instance - ) diff --git a/tests/test_examples.py b/tests/test_examples.py index 43bb7c4970..b069888db3 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -21,7 +21,6 @@ from zipline import examples from zipline.data.bundles import register, unregister -from zipline.pipeline.dispatcher import clear_all_associations from zipline.testing import test_resource_path from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase from zipline.testing.predicates import assert_equal @@ -64,8 +63,6 @@ def init_class_fixtures(cls): ) ) - cls.add_class_callback(clear_all_associations) - @parameterized.expand(sorted(examples.EXAMPLE_MODULES)) def test_example(self, example_name): actual_perf = examples.run_example( diff --git a/zipline/pipeline/__init__.py b/zipline/pipeline/__init__.py index 95b7200fe3..5ed5b2418c 100644 --- a/zipline/pipeline/__init__.py +++ b/zipline/pipeline/__init__.py @@ -9,7 +9,7 @@ from .graph import ExecutionPlan, TermGraph from .pipeline import Pipeline from 
.loaders import USEquityPricingLoader -from .dispatcher import global_pipeline_dispatcher, register_pipeline_loader +from .dispatcher import PipelineDispatcher def engine_from_files(daily_bar_path, @@ -61,6 +61,5 @@ def engine_from_files(daily_bar_path, 'SimplePipelineEngine', 'Term', 'TermGraph', - 'global_pipeline_dispatcher', - 'register_pipeline_loader', + 'PipelineDispatcher', ) diff --git a/zipline/pipeline/dispatcher.py b/zipline/pipeline/dispatcher.py index 7465dbe44c..dbd95c45b3 100644 --- a/zipline/pipeline/dispatcher.py +++ b/zipline/pipeline/dispatcher.py @@ -1,6 +1,4 @@ from zipline.pipeline.data import BoundColumn, DataSet -from zipline.pipeline.loaders.base import PipelineLoader -from zipline.utils.compat import mappingproxy class PipelineDispatcher(object): @@ -8,56 +6,23 @@ class PipelineDispatcher(object): Parameters ---------- - column_loaders : dict[BoundColumn -> PipelineLoader] - Map from columns to pipeline loader for those columns. - dataset_loaders : dict[DataSet -> PipelineLoader] - Map from datasets to pipeline loader for those datasets. + loaders : dict[BoundColumn or DataSet -> PipelineLoader] + Map from columns or datasets to pipeline loader for those objects. """ - def __init__(self, column_loaders=None, dataset_loaders=None): - self._column_loaders = column_loaders if column_loaders \ - is not None else {} - self.column_loaders = mappingproxy(self._column_loaders) - if dataset_loaders is not None: - for dataset, pl in dataset_loaders.items(): - self.register(dataset, pl) + def __init__(self, loaders): + self._column_loaders = {} + for data, pl in loaders.items(): + if isinstance(data, BoundColumn): + self._column_loaders[data] = pl + elif issubclass(data, DataSet): + for c in data.columns: + self._column_loaders[c] = pl + else: + raise TypeError("%s is neither a BoundColumn " + "nor a DataSet" % data) def __call__(self, column): if column in self._column_loaders: return self._column_loaders[column] else: raise LookupError("No pipeline loader registered for %s" % column) - - def register(self, data, pl): - """Register a given PipelineLoader to a column or columns of a dataset - - Parameters - ---------- - data : BoundColumn or DataSet - The column or dataset for which to register the PipelineLoader - pl : PipelineLoader - The PipelineLoader to register for the column or dataset columns - """ - assert isinstance(pl, PipelineLoader) - - # make it so that in either case nothing will happen if the column is - # already registered, allowing users to register their own loaders - # early on in extensions - if isinstance(data, BoundColumn): - if data not in self._column_loaders: - self._column_loaders[data] = pl - elif issubclass(data, DataSet): - for c in data.columns: - if c not in self._column_loaders: - self._column_loaders[c] = pl - else: - raise TypeError("Data provided is neither a BoundColumn " - "nor a DataSet") - - def clear(self): - """Unregisters all dataset-loader associations""" - self._column_loaders.clear() - - -global_pipeline_dispatcher = PipelineDispatcher() -register_pipeline_loader = global_pipeline_dispatcher.register -clear_all_associations = global_pipeline_dispatcher.clear diff --git a/zipline/utils/run_algo.py b/zipline/utils/run_algo.py index aff04a7235..2d654828a1 100644 --- a/zipline/utils/run_algo.py +++ b/zipline/utils/run_algo.py @@ -22,10 +22,7 @@ from zipline.data.data_portal import DataPortal from zipline.finance import metrics from zipline.finance.trading import TradingEnvironment -from zipline.pipeline import ( - 
register_pipeline_loader, - global_pipeline_dispatcher, -) +from zipline.pipeline import PipelineDispatcher from zipline.pipeline.data import USEquityPricing from zipline.pipeline.loaders import USEquityPricingLoader from zipline.utils.factory import create_simulation_parameters @@ -76,7 +73,8 @@ def _run(handle_data, print_algo, metrics_set, local_namespace, - environ): + environ, + pipeline_dispatcher): """Run a backtest for the given algorithm. This is shared between the cli and :func:`zipline.run_algo`. @@ -137,7 +135,7 @@ def _run(handle_data, ), ) - if bundle is not None: + if bundle is not None and pipeline_dispatcher is None: bundle_data = bundles.load( bundle, environ, @@ -171,16 +169,14 @@ def _run(handle_data, bundle_data.adjustment_reader, ) - # we register our default loader last, after any loaders from users - # have been registered via extensions - register_pipeline_loader(USEquityPricing, pipeline_loader) - - def choose_loader(column): - return global_pipeline_dispatcher(column) + # create the default dispatcher + pipeline_dispatcher = PipelineDispatcher( + {USEquityPricing: pipeline_loader} + ) else: env = TradingEnvironment(environ=environ) - choose_loader = None + pipeline_dispatcher = None if isinstance(metrics_set, six.string_types): try: @@ -191,7 +187,7 @@ def choose_loader(column): perf = TradingAlgorithm( namespace=namespace, env=env, - get_pipeline_loader=choose_loader, + get_pipeline_loader=pipeline_dispatcher, trading_calendar=trading_calendar, sim_params=create_simulation_parameters( start=start, @@ -235,7 +231,7 @@ def load_extensions(default, extensions, strict, environ, reload=False): ---------- default : bool Load the default exension (~/.zipline/extension.py)? - extension : iterable[str] + extensions : iterable[str] The paths to the extensions to load. If the path ends in ``.py`` it is treated as a script and executed. If it does not end in ``.py`` it is treated as a module to be imported. @@ -292,7 +288,8 @@ def run_algorithm(start, default_extension=True, extensions=(), strict_extensions=True, - environ=os.environ): + environ=os.environ, + pipeline_dispatcher=None): """Run a trading algorithm. Parameters @@ -351,6 +348,9 @@ def run_algorithm(start, environ : mapping[str -> str], optional The os environment to use. Many extensions use this to get parameters. This defaults to ``os.environ``. + pipeline_dispatcher : PipelineDispatcher, optional + The pipeline dispatcher to use, which should contains any column-to- + loader associations necessary to run the trading algorithm Returns ------- @@ -403,4 +403,5 @@ def run_algorithm(start, metrics_set=metrics_set, local_namespace=False, environ=environ, + pipeline_dispatcher=pipeline_dispatcher, )
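
A minimal usage sketch of the final API introduced by PATCH 5/5, assuming a zipline checkout with this series applied; CustomData and CustomLoader are illustrative stand-ins for a user-defined dataset and loader, not names from the patches:

from zipline.pipeline import PipelineDispatcher
from zipline.pipeline.data import Column, DataSet
from zipline.pipeline.loaders.base import PipelineLoader
from zipline.utils.numpy_utils import float64_dtype


class CustomData(DataSet):
    # Illustrative user-defined dataset with a single float column.
    value = Column(float64_dtype)


class CustomLoader(PipelineLoader):
    # Illustrative loader; a real implementation would return an
    # AdjustedArray for each requested column.
    def load_adjusted_array(self, columns, dates, assets, mask):
        raise NotImplementedError


loader = CustomLoader()

# A DataSet key is expanded into all of its BoundColumns at construction
# time; individual BoundColumns may also be used as keys directly.
dispatcher = PipelineDispatcher({CustomData: loader})

# The dispatcher is callable: it resolves a column to its registered loader
# and raises LookupError for columns with no registered loader.
assert dispatcher(CustomData.value) is loader

# The same object can then be passed as
# run_algorithm(..., pipeline_dispatcher=dispatcher), which PATCH 5/5
# forwards to TradingAlgorithm as get_pipeline_loader.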