Skip to content

Commit

Permalink
DEV: add testing utility for checking term lookback windows
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Jevnik authored and llllllllll committed Feb 25, 2019
1 parent 6110ce3 commit 05a6080
Showing 1 changed file with 87 additions and 0 deletions.
87 changes: 87 additions & 0 deletions zipline/pipeline/factors/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import numpy as np

from zipline.testing.predicates import assert_equal
from .factor import CustomFactor


class IDBox(object):
"""A wrapper that hashs to the id of the underlying object and compares
equality on the id of the underlying.
Parameters
----------
ob : any
The object to wrap.
Attributes
----------
ob : any
The object being wrapped.
Notes
-----
This is useful for storing non-hashable values in a set or dict.
"""
def __init__(self, ob):
self.ob = ob

def __hash__(self):
return id(self)

def __eq__(self, other):
if not isinstance(other, IDBox):
return NotImplemented

return id(self.ob) == id(other.ob)


class CheckWindowsFactor(CustomFactor):
"""A custom factor that makes assertions about the lookback windows that
it gets passed.
Parameters
----------
input_ : Term
The input term to the factor.
window_length : int
The length of the lookback window.
expected_windows : dict[int, dict[pd.Timestamp, np.ndarray]]
For each asset, for each day, what the expected lookback window is.
Notes
-----
The output of this factor is the same as ``Latest``. Any assets or days
not in ``expected_windows`` are not checked.
"""
params = ('expected_windows',)

def __new__(cls, input_, window_length, expected_windows):
return super(CheckWindowsFactor, cls).__new__(
cls,
inputs=[input_],
dtype=input_.dtype,
window_length=window_length,
expected_windows=frozenset(
(k, IDBox(v)) for k, v in expected_windows.items()
),
)

def compute(self, today, assets, out, input_, expected_windows):
for asset, expected_by_day in expected_windows:
expected_by_day = expected_by_day.ob

col_ix = np.searchsorted(assets, asset)
if assets[col_ix] != asset:
raise AssertionError('asset %s is not in the window' % asset)

try:
expected = expected_by_day[today]
except KeyError:
pass
else:
expected = np.array(expected)
actual = input_[:, col_ix]
assert_equal(actual, expected)

# output is just latest
out[:] = input_[-1]

0 comments on commit 05a6080

Please sign in to comment.