diff --git a/src/venturescript/vscgpm.py b/src/venturescript/vscgpm.py index 86debba1..ea717108 100644 --- a/src/venturescript/vscgpm.py +++ b/src/venturescript/vscgpm.py @@ -172,17 +172,15 @@ def from_metadata(cls, metadata, rng=None): # Internal helpers. def _predict_cell(self, rowid, target, inputs, label): - inputs_list = [inputs[i] for i in self.inputs] - sp_args = str.join(' ', map(str, [rowid] + inputs_list)) - i = self.outputs.index(target) - return self.ripl.predict( - '((lookup outputs %i) %s)' % (i, sp_args), label=label) + output_idx = self.outputs.index(target) + sp_args = self._get_sp_args(rowid, inputs) + return self.ripl.predict('((lookup outputs %i) %s)' + % (output_idx, sp_args), label=label) def _observe_cell(self, rowid, query, value, inputs): output_idx = self.outputs.index(query) - inputs_list = [inputs[i] for i in self.inputs] label = self._gen_label() - sp_args = '%d %s' % (rowid, ' '.join(map(str, inputs_list))) + sp_args = self._get_sp_args(rowid, inputs) if not self.obs_override: self.ripl.observe('((lookup outputs %i) %s)' % (output_idx, sp_args), value, label=label) @@ -204,6 +202,11 @@ def _gen_label(self): self.rng.randint(1,100), datetime.now().strftime('%Y%m%d%H%M%S%f')) + def _get_sp_args(self, rowid, inputs): + sp_rowid = '(atom %d)' % (rowid,) + sp_input = ' '.join(map(str, [inputs[i] for i in self.inputs])) + return '%s %s' % (sp_rowid, sp_input) + def _validate_incorporate(self, rowid, observation, inputs=None): inputs = inputs or {} if not observation: