Skip to content

Commit

Permalink
Fix #240, default implementation of observers.
Browse files Browse the repository at this point in the history
Requries user who overrides one observer to override them all.
Consider using a dictionary for user to override only some observers.
  • Loading branch information
Feras A Saad committed Jan 15, 2018
1 parent 2fe0e9f commit 93331de
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 14 deletions.
33 changes: 27 additions & 6 deletions src/venturescript/vscgpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import base64
import copy
import math
Expand All @@ -23,6 +24,8 @@

import venture.shortcuts as vs

from venture.exception import VentureException

from cgpm.cgpm import CGpm
from cgpm.utils import config as cu
from cgpm.utils import general as gu
Expand Down Expand Up @@ -71,7 +74,9 @@ def __init__(self, outputs, inputs, rng=None, sp=None, **kwargs):
raise ValueError('source.inputs list disagrees with inputs.')
self.inputs = inputs
# Check overriden observers.
if len(self.outputs) != self.ripl.evaluate('(size observers)'):
num_observers = self._get_num_observers()
self.obs_override = num_observers is not None
if self.obs_override and len(self.outputs) != num_observers:
raise ValueError('source.observers list disagrees with outputs.')
# XXX Eliminate this nested defaultdict
# Inputs and labels for incorporate/unincorporate.
Expand Down Expand Up @@ -174,12 +179,20 @@ def _predict_cell(self, rowid, target, inputs, label):
'((lookup outputs %i) %s)' % (i, sp_args), label=label)

def _observe_cell(self, rowid, query, value, inputs):
output_id = self.outputs.index(query)
inputs_list = [inputs[i] for i in self.inputs]
label = '\''+self._gen_label()
sp_args = str.join(' ', map(str, [rowid] + inputs_list + [value, label]))
i = self.outputs.index(query)
self.ripl.evaluate('((lookup observers %i) %s)' % (i, sp_args))
self.obs[rowid]['labels'][query] = label[1:]
label = self._gen_label()
if self.obs_override:
qlabel = '(quote %s)' % (label,)
sp_args = ' '.join(map(str,
itertools.chain([rowid], inputs_list, [value, qlabel])))
self.ripl.evaluate('((lookup observers %i) %s)'
% (output_id, sp_args))
else:
sp_args = ' '.join(map(str, itertools.chain([rowid], inputs_list)))
self.ripl.observe('((lookup outputs %i) %s)'
% (output_id, sp_args), value, label=label)
self.obs[rowid]['labels'][query] = label

def _forget_cell(self, rowid, query):
if query not in self.obs[rowid]['labels']:
Expand Down Expand Up @@ -255,6 +268,14 @@ def _check_matched_inputs(self, rowid, inputs):
raise ValueError('Given inputs contradicts dataset: %d, %s, %s' %
(rowid, inputs, self.obs[rowid]['inputs']))

def _get_num_observers(self):
# Return the length of the "observers" list defined by the client, or
# None if the client did not override the observers.
try:
return self.ripl.evaluate('(size observers)')
except VentureException:
return None

@staticmethod
def _obs_to_json(obs):
def convert_key_int_to_str(d):
Expand Down
44 changes: 36 additions & 8 deletions tests/test_vscgpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,13 @@
(lambda (rowid w value label)
(observe (simulate_y ,rowid ,w) value ,label))]
[define observers (list observe_m
observe_y)]
[define inputs (list 'w)]
[define transition
(lambda (N)
(mh default one N))]
"""
"""

source_concrete = """
define make_cgpm = () -> {
Expand Down Expand Up @@ -100,19 +97,42 @@
$label: observe simulate_y($rowid, $w) = value;
};
define observers = [observe_m, observe_y];
define inputs = ["w"];
define transition = (N) -> {
mh(default, one, N)
};
"""

# Define source with client overriding observers.
source_abstract_observers_good = source_abstract + \
'[define observers (list observe_m observe_y)]\n'
source_abstract_observers_bad = source_abstract + \
'[define observers (list observe_m observe_y 2)]\n'

source_concrete_observers_good = source_concrete + \
'define observers = [observe_m, observe_y];\n'
source_concrete_observers_bad = source_concrete + \
'define observers = [observe_m, observe_y, 2];\n'

# Define test cases.
Case = namedtuple('Case', ['source', 'mode'])
cases = [
Case(source_abstract, 'church_prime'),
Case(source_concrete, 'venture_script'),
Case(source_abstract, 'church_prime'),
Case(source_concrete, 'venture_script'),
Case(source_abstract_observers_good, 'church_prime'),
Case(source_concrete_observers_good, 'venture_script'),
]

CaseObs = namedtuple('Case', ['source', 'obsok', 'mode'])
casesObs = [
CaseObs(source_abstract, True, 'church_prime'),
CaseObs(source_concrete, True, 'venture_script'),
CaseObs(source_abstract_observers_good, True, 'church_prime'),
CaseObs(source_concrete_observers_good, True, 'venture_script'),
CaseObs(source_abstract_observers_bad, False, 'church_prime'),
CaseObs(source_concrete_observers_bad, False, 'venture_script'),
]

@pytest.mark.parametrize('case', cases)
Expand All @@ -133,6 +153,14 @@ def test_wrong_inputs(case):
with pytest.raises(ValueError):
VsCGpm(outputs=[1,2], inputs=[3,4], source=case.source, mode=case.mode)

@pytest.mark.parametrize('case', casesObs)
def test_wrong_observers(case):
try:
VsCGpm(outputs=[0,1], inputs=[2], source=case.source, mode=case.mode)
assert case.obsok
except ValueError:
assert not case.obsok

@pytest.mark.parametrize('case', cases)
def test_incorporate_unincorporate(case):
cgpm = VsCGpm(outputs=[0,1], inputs=[3], source=case.source, mode=case.mode)
Expand Down

0 comments on commit 93331de

Please sign in to comment.