|
7 | 7 | from itertools import combinations, product
|
8 | 8 | from logic import (FolKB, constant_symbols, predicate_symbols, standardize_variables,
|
9 | 9 | variables, is_definite_clause, subst, expr, Expr)
|
| 10 | +from functools import partial |
10 | 11 |
|
11 | 12 | # ______________________________________________________________________________
|
12 | 13 |
|
@@ -297,44 +298,59 @@ def new_literals(self, clause):
|
297 | 298 | share_vars = variables(clause[0])
|
298 | 299 | for l in clause[1]:
|
299 | 300 | share_vars.update(variables(l))
|
300 |
| - |
301 | 301 | for pred, arity in self.pred_syms:
|
302 | 302 | new_vars = {standardize_variables(expr('x')) for _ in range(arity - 1)}
|
303 | 303 | for args in product(share_vars.union(new_vars), repeat=arity):
|
304 | 304 | if any(var in share_vars for var in args):
|
305 |
| - yield Expr(pred, *[var for var in args]) |
| 305 | + # make sure we don't return an existing rule |
| 306 | + if not Expr(pred, args) in clause[1]: |
| 307 | + yield Expr(pred, *[var for var in args]) |
306 | 308 |
|
307 |
| - def choose_literal(self, literals, examples): |
308 |
| - """Choose the best literal based on the information gain.""" |
309 |
| - def gain(l): |
310 |
| - pre_pos = len(examples[0]) |
311 |
| - pre_neg = len(examples[1]) |
312 |
| - extended_examples = [sum([list(self.extend_example(example, l)) for example in |
313 |
| - examples[i]], []) for i in range(2)] |
314 |
| - post_pos = len(extended_examples[0]) |
315 |
| - post_neg = len(extended_examples[1]) |
316 |
| - if pre_pos + pre_neg == 0 or post_pos + post_neg == 0: |
317 |
| - return -1 |
318 | 309 |
|
319 |
| - # number of positive example that are represented in extended_examples |
320 |
| - T = 0 |
321 |
| - for example in examples[0]: |
322 |
| - def represents(d): |
323 |
| - return all(d[x] == example[x] for x in example) |
324 |
| - if any(represents(l_) for l_ in extended_examples[0]): |
325 |
| - T += 1 |
def choose_literal(self, literals, examples):
    """Return the literal from `literals` that scores highest under `self.gain`.

    Each candidate is scored with `self.gain(candidate, examples=examples)`.
    When several candidates share the best score, the first one encountered
    is returned (the behavior of the built-in `max`).
    """

    def score(candidate):
        # `examples` is passed by keyword, exactly as the scoring call expects.
        return self.gain(candidate, examples=examples)

    best = max(literals, key=score)
    return best
| 314 | + |
| 315 | + |
def gain(self, l, examples):
    """Return the utility of adding literal `l` to the body of the clause.

    The utility function is:

        gain(R, l) = T * (log_2(post_pos / (post_pos + post_neg))
                          - log_2(pre_pos / (pre_pos + pre_neg)))

    where:

        pre_pos  = number of positive bindings of rule R (the current clause)
        pre_neg  = number of negative bindings of rule R
        post_pos = number of positive bindings of rule R' (= R U {l})
        post_neg = number of negative bindings of rule R'
        T        = number of positive bindings of rule R that are still
                   covered after adding literal l

    Parameters:
        l: the candidate literal; it is handed to `self.extend_example`.
        examples: a pair `(positive, negative)` of binding collections.
            Each binding is iterated over its keys and indexed by them.

    Returns:
        The gain as a float, or -1 when it is undefined: there are no
        positive bindings to begin with, or no binding at all survives the
        extension with `l`.
    """
    positives, negatives = examples[0], examples[1]
    pre_pos = len(positives)
    pre_neg = len(negatives)

    # Flatten the extensions with a nested comprehension; the former
    # sum(list_of_lists, []) re-copied the accumulator on every step.
    post_pos = [ext for example in positives
                for ext in self.extend_example(example, l)]
    post_neg = [ext for example in negatives
                for ext in self.extend_example(example, l)]

    # With no positive bindings, pre_pos / (pre_pos + pre_neg) is 0 and its
    # logarithm is undefined (it used to raise ValueError). T would be 0
    # anyway, so report the same "useless literal" sentinel as the other
    # degenerate cases.
    if pre_pos == 0 or not (post_pos or post_neg):
        return -1

    def still_covered(example):
        # An extended binding represents `example` when it agrees with it
        # on every key of `example`.
        return any(all(d[x] == example[x] for x in example) for d in post_pos)

    # Number of positive examples that are represented in post_pos.
    T = sum(1 for example in positives if still_covered(example))

    post_ratio = len(post_pos) / (len(post_pos) + len(post_neg))
    pre_ratio = pre_pos / (pre_pos + pre_neg)
    # The 1e-12 offset keeps the first logarithm finite when post_pos is empty.
    return T * (log(post_ratio + 1e-12, 2) - log(pre_ratio, 2))
328 | 346 |
|
329 |
| - return max(literals, key=gain) |
330 | 347 |
|
331 | 348 | def update_examples(self, target, examples, extended_examples):
|
332 | 349 | """Add to the kb those examples what are represented in extended_examples
|
333 | 350 | List of omitted examples is returned."""
|
334 | 351 | uncovered = []
|
335 | 352 | for example in examples:
|
336 |
| - def represents(d): |
337 |
| - return all(d[x] == example[x] for x in example) |
| 353 | + represents = lambda d: all(d[x] == example[x] for x in example) |
338 | 354 | if any(represents(l) for l in extended_examples):
|
339 | 355 | self.tell(subst(example, target))
|
340 | 356 | else:
|
@@ -400,3 +416,8 @@ def false_positive(e, h):
|
400 | 416 |
|
def false_negative(e, h):
    """Report whether hypothesis `h` misses example `e`.

    The result is truthy when `e["GOAL"]` is truthy and `guess_value(e, h)`
    is falsy. When `e["GOAL"]` is falsy it is returned as-is and
    `guess_value` is not consulted.
    """
    goal = e["GOAL"]
    if not goal:
        return goal
    return not guess_value(e, h)
|
| 419 | + |
| 420 | + |
| 421 | + |
| 422 | + |
| 423 | + |
0 commit comments