Skip to content

Commit

Permalink
Fix #242, use atom for rowids only.
Browse files Browse the repository at this point in the history
  • Loading branch information
Feras A Saad committed Jan 15, 2018
1 parent 9e1650d commit 97980cb
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions src/venturescript/vscgpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,15 @@ def from_metadata(cls, metadata, rng=None):
# Internal helpers.

def _predict_cell(self, rowid, target, inputs, label):
inputs_list = [inputs[i] for i in self.inputs]
sp_args = str.join(' ', map(str, [rowid] + inputs_list))
i = self.outputs.index(target)
return self.ripl.predict(
'((lookup outputs %i) %s)' % (i, sp_args), label=label)
output_idx = self.outputs.index(target)
sp_args = self._get_sp_args(rowid, inputs)
return self.ripl.predict('((lookup outputs %i) %s)'
% (output_idx, sp_args), label=label)

def _observe_cell(self, rowid, query, value, inputs):
output_idx = self.outputs.index(query)
inputs_list = [inputs[i] for i in self.inputs]
label = self._gen_label()
sp_args = '%d %s' % (rowid, ' '.join(map(str, inputs_list)))
sp_args = self._get_sp_args(rowid, inputs)
if not self.obs_override:
self.ripl.observe('((lookup outputs %i) %s)'
% (output_idx, sp_args), value, label=label)
Expand All @@ -204,6 +202,11 @@ def _gen_label(self):
self.rng.randint(1,100),
datetime.now().strftime('%Y%m%d%H%M%S%f'))

def _get_sp_args(self, rowid, inputs):
sp_rowid = '(atom %d)' % (rowid,)
sp_input = ' '.join(map(str, [inputs[i] for i in self.inputs]))
return '%s %s' % (sp_rowid, sp_input)

def _validate_incorporate(self, rowid, observation, inputs=None):
inputs = inputs or {}
if not observation:
Expand Down

0 comments on commit 97980cb

Please sign in to comment.