API Use Cases
- Incorporate statically-generated weights in facts, features
- Adapt duplicate facts checker to larger databases
- Write outputs to disk as they finish, not in the order they were read
- Trace weight behavior of a feature during training
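Incorporate statically-generated weights in facts, features
Statically-generated weights are useful for tasks like text categorization, where you want ProPPR to respond to a precomputed score, such as TF/IDF, on the link between a document and a word.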
Here's an example, which you can also find in examples/walkthru/textcat:
# Simple program for text classification, illustrating how to attach a
# classifier to a ProPPR rule.
# Pick a label Y for X, and decide if it's a good classification by
# calling ab_classify.
predict(X,Y) :- isLabel(Y), ab_classify(X,Y).
# "abduce" (i.e., guess at) a classification Y for document X. The
# antecedent of the rule is empty, so it always succeeds, but the
# weight for this rule will be based on features generated by the
# annotation { f#(Word,Y,Weight): hasWord#(X,Word,Weight) } -- i.e., the
# words in document X will be paired with the label Y and used as features,
# each weighted by the Weight of its hasWord fact. Note that the
# weight of this rule will compete with the weight of the implicit
# 'reset' rule.
ab_classify(X,Y) :- { f#(Word,Y,Weight): hasWord#(X,Word,Weight) }.
The # denotes a weighted relation. In a weighted relation, the last argument stands for the weight of the fact and does not count toward the relation's final arity. As a result, your learned-params file will contain only f(house,pos), and not a thousand variants like f(house,pos,23.12344356) with different weights.
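To make the arity rule concrete, here is a minimal Java sketch of what the annotation { f#(Word,Y,Weight): hasWord#(X,Word,Weight) } conceptually produces for one document and one candidate label. The feature name is built from every argument except the weight, so f(house,pos) is the only parameter key. The weight is carried alongside the feature as its value, which is an assumption about how the value is used. The class and method names are hypothetical; this illustrates the idea and is not ProPPR's internal code.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical illustration of { f#(Word,Y,Weight): hasWord#(X,Word,Weight) }.
final class WeightedFeatures {

    /** Builds a feature name such as f(house,pos); the weight is not one of the arguments. */
    static String featureName(String functor, String... args) {
        return functor + "(" + String.join(",", args) + ")";
    }

    /**
     * Pairs every word of one document with a label.
     * wordWeights maps Word -> Weight, i.e. the hasWord#(X,Word,Weight) facts for document X.
     * Assumption: the fact weight is attached to the generated feature as its value.
     */
    static Map<String, Double> features(Map<String, Double> wordWeights, String label) {
        Map<String, Double> result = new LinkedHashMap<>();
        for (Map.Entry<String, Double> e : wordWeights.entrySet()) {
            result.put(featureName("f", e.getKey(), label), e.getValue());
        }
        return result;
    }
}
```

For train00001 with label pos, the word house (weight 0.5005) yields the key f(house,pos) with value 0.5005. The key stays the same no matter what the weight is.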
The weights are specified in a db file. In this case, it's a graph file:
$ cat toycorpus.graph
hasWord train00001 a 0.0589
hasWord train00001 house 0.5005
hasWord train00001 pricy 0.5005
hasWord train00001 doll 0.7040
hasWord train00002 a 0.0510
hasWord train00002 fire 0.6097
hasWord train00002 little 0.4334
hasWord train00002 truck 0.6097
...
ProPPR knows that a fourth element on a line must be a weight, since .graph files only support arity-2 relations.
You can also specify weights on facts of arbitrary arity in a .facts file, but you have to tell ProPPR explicitly that the last element is a weight and not part of the relation. For simplicity we use the same syntax as in the rules file, adding a # to the end of the functor:
$ cat toycorpus.facts
hasWord# train00001 a 0.0589
hasWord# train00001 house 0.5005
hasWord# train00001 pricy 0.5005
hasWord# train00001 doll 0.7040
hasWord# train00002 a 0.0510
hasWord# train00002 fire 0.6097
hasWord# train00002 little 0.4334
hasWord# train00002 truck 0.6097
...
Weights are supported in the facts, graph, and sparsegraph plugins as of commit ce79246cd.
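The two formats locate the weight differently:

- In a .graph line, a fourth column can only be a weight.
- In a .facts line, a trailing # on the functor marks the last column as a weight.

The minimal Java sketch below illustrates that logic under stated assumptions. It splits on any whitespace, so it accepts the listings above as well as tab-separated files. It strips the # from the functor, and it gives unweighted facts a default weight of 1.0; both are choices made for this sketch. The WeightedFact type and its parse methods are hypothetical and are not ProPPR's plugin API.

```java
import java.util.Arrays;

// Hypothetical parser showing how the weight column is recognized in each format.
final class WeightedFact {
    final String functor;
    final String[] args;   // only these count toward the arity
    final double weight;

    WeightedFact(String functor, String[] args, double weight) {
        this.functor = functor;
        this.args = args;
        this.weight = weight;
    }

    int arity() { return args.length; }

    /** .graph line "rel src dst [weight]": relations are arity 2, so a 4th column is a weight. */
    static WeightedFact parseGraphLine(String line) {
        String[] cols = line.trim().split("\\s+");
        if (cols.length != 3 && cols.length != 4) {
            throw new IllegalArgumentException("graph lines need 3 or 4 columns: " + line);
        }
        double w = cols.length == 4 ? Double.parseDouble(cols[3]) : 1.0; // assumed default weight
        return new WeightedFact(cols[0], new String[] {cols[1], cols[2]}, w);
    }

    /** .facts line: a functor ending in '#' marks the last column as a weight. */
    static WeightedFact parseFactsLine(String line) {
        String[] cols = line.trim().split("\\s+");
        String functor = cols[0];
        if (functor.endsWith("#")) {
            if (cols.length < 3) {
                throw new IllegalArgumentException("weighted fact needs arguments and a weight: " + line);
            }
            double w = Double.parseDouble(cols[cols.length - 1]);
            return new WeightedFact(functor.substring(0, functor.length() - 1),
                    Arrays.copyOfRange(cols, 1, cols.length - 1), w);
        }
        // Unweighted fact: every column after the functor is an argument.
        return new WeightedFact(functor, Arrays.copyOfRange(cols, 1, cols.length), 1.0);
    }
}
```

Both "hasWord train00001 house 0.5005" from the .graph listing and "hasWord# train00001 house 0.5005" from the .facts listing parse to the arity-2 relation hasWord with weight 0.5005.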
Adapt duplicate facts checker to larger databases
Duplicate lines in a fact file show up in proof graphs as edges with increased weight, which is usually not what you meant. By default, FactsPlugin (*.facts) and GraphPlugin (*.graph) add each fact to a Bloom filter with a false positive probability of 0.00001 and an estimated size of 1,000,000 facts. This lets them cheaply check whether a fact has been seen before, without storing every fact in a lookup-friendly format. If a duplicate fact is detected, the plugin prints an error message and skips the fact. In theory, this means a false positive can result in an incomplete database, but in practice we have yet to run into this issue.
If you load more than 1M facts, the plugin prints a warning that false positives have become more likely.
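The textbook Bloom filter formulas give a rough estimate of why the warning matters. The symbols are:

- n: the expected number of facts.
- p: the target false positive rate.
- m: the number of bits.
- k: the number of hash functions.
- n': the number of facts actually loaded.

The numbers below assume the filter is sized with the standard optimal settings for the defaults n = 10^6 and p = 10^-5. ProPPR's actual bit-array size and hash count may differ.

```latex
\begin{aligned}
m &= -\frac{n \ln p}{(\ln 2)^2} = \frac{10^{6} \times 11.513}{0.4805} \approx 2.40 \times 10^{7} \text{ bits } (\approx 3 \text{ MB}) \\
k &= \frac{m}{n} \ln 2 \approx 16.6 \;\rightarrow\; 17 \\
p(n') &\approx \left(1 - e^{-k n'/m}\right)^{k} \\
p(10^{6}) &\approx \left(1 - e^{-0.709}\right)^{17} \approx 1.0 \times 10^{-5} \\
p(2 \times 10^{6}) &\approx \left(1 - e^{-1.419}\right)^{17} \approx 9 \times 10^{-3}
\end{aligned}
```

Under these assumptions, once twice the estimated number of facts has been loaded, each new distinct fact has roughly a 0.9% chance of being mistaken for a duplicate and skipped. That is nearly a thousandfold worse than the target rate.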
You can increase the size of the Bloom filter to accommodate larger fact databases using the --duplicateCheck option: just set it to the number of lines in your fact file.
Turn off duplicate checking entirely by setting --duplicateCheck -1.
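The duplicate check described above can be sketched in Java as a plain Bloom filter over java.util.BitSet:

- It is sized from an expected fact count and a target false positive rate using the standard formulas.
- It derives its k bit positions by double hashing.
- A negative expected size disables checking, mimicking --duplicateCheck -1.
- It prints a warning once the expected size is exceeded.

The DuplicateChecker class, its hash functions, and its sizing are hypothetical. ProPPR's own filter implementation may differ in all of these details.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.BitSet;

// Hypothetical Bloom-filter duplicate checker modeled on the behavior described above.
final class DuplicateChecker {
    private final BitSet bits;
    private final int numBits;
    private final int numHashes;
    private final long expected;
    private long added = 0;

    /** expectedSize < 0 disables checking, like --duplicateCheck -1. */
    DuplicateChecker(long expectedSize, double falsePositiveRate) {
        this.expected = expectedSize;
        if (expectedSize < 0) {
            this.bits = null;
            this.numBits = 0;
            this.numHashes = 0;
            return;
        }
        long n = Math.max(1L, expectedSize);
        // Standard optimal sizing: m = -n ln p / (ln 2)^2 bits, k = (m / n) ln 2 hash functions.
        double m = -n * Math.log(falsePositiveRate) / (Math.log(2) * Math.log(2));
        this.numBits = (int) Math.min(Integer.MAX_VALUE, Math.ceil(m));
        this.numHashes = Math.max(1, (int) Math.round((double) numBits / n * Math.log(2)));
        this.bits = new BitSet(numBits);
    }

    int numHashes() { return numHashes; }

    /** Returns true if the fact is new (and records it); false if it looks like a duplicate to skip. */
    boolean addIfNew(String fact) {
        if (bits == null) return true; // checking disabled
        byte[] data = fact.getBytes(StandardCharsets.UTF_8);
        int h1 = Arrays.hashCode(data);
        int h2 = fnv1a(data);
        boolean allSet = true;
        for (int i = 0; i < numHashes; i++) {
            int idx = Math.floorMod(h1 + i * h2, numBits); // double hashing: h1 + i * h2
            if (!bits.get(idx)) {
                allSet = false;
                bits.set(idx);
            }
        }
        if (allSet) return false; // probably seen before (could be a false positive)
        added++;
        if (added == expected + 1) {
            System.err.println("warning: more than " + expected + " facts loaded; false positives are more likely");
        }
        return true;
    }

    /** 32-bit FNV-1a hash, forced odd so successive probe positions differ. */
    private static int fnv1a(byte[] data) {
        int h = 0x811C9DC5;
        for (byte b : data) {
            h ^= (b & 0xFF);
            h *= 0x01000193;
        }
        return h | 1;
    }
}
```

A loader would call addIfNew for each line, and print an error and skip the line whenever it returns false. With the defaults of 1,000,000 facts and a rate of 0.00001, this sizing picks 17 hash functions.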
Write outputs to disk as they finish, not in the order they were read
By default, QueryAnswerer and Grounder write output files so that the nth example in the query file is the nth example in the solutions/grounded file. If your examples vary in complexity, this can leave a pile of finished examples sitting in memory while they wait for a slow example to finish. For example, suppose query 1 is very slow to compute and queries 2-16 are very fast. A 16-thread run will begin work on all 16 examples. Fifteen of the threads will finish and move on to queries later in the file, but nothing will be written to disk until query 1 finishes. If you don't need your output to be in the same order as the query file, you can save memory by enabling reordering.
Command line option:
--order reorder
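A small simulation estimates the memory cost of keeping output in order. This hypothetical Java sketch takes the time at which each query finishes and reports the peak number of results that are finished but not yet written, in two cases:

- when results must be written in query order;
- when each result is written as soon as it finishes.

It ignores thread scheduling and treats writes as instantaneous, so it only illustrates the effect described above.

```java
// Hypothetical simulation of how many finished results wait in memory; times are arbitrary units.
final class OrderingCost {

    /** In-order output: result i is written only once results 0..i have all finished. */
    static int peakWaitingInOrder(double[] finish) {
        double[] write = new double[finish.length];
        double latest = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < finish.length; i++) {
            latest = Math.max(latest, finish[i]);
            write[i] = latest;
        }
        return peakWaiting(finish, write);
    }

    /** Reordered output: each result is written as soon as it finishes. */
    static int peakWaitingReordered(double[] finish) {
        return peakWaiting(finish, finish.clone());
    }

    /** Peak count of results inside [finish_i, write_i), i.e. finished but not yet written. */
    private static int peakWaiting(double[] finish, double[] write) {
        int peak = 0;
        for (double t : finish) { // the waiting count can only rise at a finish time
            int waiting = 0;
            for (int i = 0; i < finish.length; i++) {
                if (finish[i] <= t && t < write[i]) waiting++;
            }
            peak = Math.max(peak, waiting);
        }
        return peak;
    }
}
```

Take finish times {100, 1, 2, ..., 15}, meaning query 1 is slow and queries 2-16 are fast. Here peakWaitingInOrder returns 15 and peakWaitingReordered returns 0.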
How it works:
The multithreading harness uses the Futures pattern to manage tasks. Each query generates a "transformer" task and a "cleanup" task, where the output of the transformer is the input to the cleanup. The transformer pool uses #nthreads threads (specified using --threads), and the cleanup pool uses only one thread, since letting multiple threads write to disk is problematic. Java guarantees that tasks will be picked up in the order in which they were added to the pool. By default, the cleanup task blocks until its transformer has finished. With --order reorder, the cleanup task waits a maximum of 20ms for the transformer to finish. If the transformer hasn't finished by then, the cleanup task resubmits itself, placing it at the end of the queue. This frees the cleanup thread to proceed to the next example.
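The Java sketch below mirrors that mechanism with java.util.concurrent. It has:

- a fixed pool of transformer threads;
- a single-threaded cleanup executor, whose work queue is FIFO;
- a cleanup task that either blocks on its transformer's Future or, in reorder mode, waits at most 20 ms and then requeues itself at the back of the cleanup queue.

The Harness class, the String-valued queries, and the List used as a stand-in for the output file are assumptions for illustration. This is not ProPPR's actual harness code.

```java
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;

// Hypothetical sketch of the transformer/cleanup Futures pattern (not ProPPR's harness).
final class Harness {
    private final ExecutorService transformers;                                  // #nthreads workers
    private final ExecutorService cleanup = Executors.newSingleThreadExecutor(); // one writer, FIFO
    private final boolean reorder;
    private int outstanding = 0;                                                 // guarded by this

    Harness(int nthreads, boolean reorder) {
        this.transformers = Executors.newFixedThreadPool(nthreads);
        this.reorder = reorder;
    }

    /** Each query gets a transformer task and a cleanup task that consumes its output. */
    void submit(String query, Function<String, String> work, List<String> out) {
        synchronized (this) { outstanding++; }
        Future<String> result = transformers.submit(() -> work.apply(query));
        cleanup.execute(new Cleanup(result, out));
    }

    private synchronized void finished() {
        if (--outstanding == 0) notifyAll();
    }

    /** Blocks until every submitted query has been written (or has failed), then stops both pools. */
    void awaitAndShutdown() {
        synchronized (this) {
            while (outstanding > 0) {
                try {
                    wait();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt(); // stop waiting if interrupted
                    break;
                }
            }
        }
        transformers.shutdown();
        cleanup.shutdown();
    }

    private final class Cleanup implements Runnable {
        private final Future<String> result;
        private final List<String> out;

        Cleanup(Future<String> result, List<String> out) {
            this.result = result;
            this.out = out;
        }

        @Override
        public void run() {
            try {
                String value = reorder
                        ? result.get(20, TimeUnit.MILLISECONDS) // --order reorder: wait at most 20 ms
                        : result.get();                        // default: block until this query is done
                out.add(value);                                // stand-in for writing to disk
            } catch (TimeoutException notYet) {
                cleanup.execute(this); // requeue at the back so the writer can move on
                return;                // still outstanding
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } catch (ExecutionException e) {
                System.err.println("query failed: " + e.getCause());
            }
            finished();
        }
    }
}
```

In the default mode, the output list always matches submission order. In reorder mode, a slow first query no longer holds back the queries behind it.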
Trace weight behavior of a feature during training
Command line option:
--srw "traceFeature=db(LightweightGraphPlugin,webkb.graph)[:...]"
log4j spec:
log4j.logger.edu.cmu.ml.proppr.learn.SRW=INFO
log4j.logger.edu.cmu.ml.proppr.learn.PosNegLossTrainedSRW=INFO
Example log4j appender:
log4j.appender.consoleout.layout.ConversionPattern=%5p [%c{1}] %t %m%n
where %5p = log level (padded to 5 characters), %c{1} = logger (class) name, last component only, %t = thread name, %m = message, %n = newline
Output syntax:
- reg: lists the regularization component of the gradient for the traced feature.
  INFO [PosNegLossTrainedSRW] transformer-1 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
- logP: lists the log-loss, positive-label component of the gradient for the traced feature.
  INFO [PosNegLossTrainedSRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) logP -7.394501070610791E-4 / 0.022347119637303438 = -0.0330892803664386
- logN: lists the log-loss, negative-label component of the gradient for the traced feature.
  INFO [PosNegLossTrainedSRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) logN 1.0 * 7.389713595330534E-4 / 0.9776673487199024 = 7.558515281303167E-4
- was: lists the old parameter value for the traced feature.
  INFO [SRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) was 1.0060924029682254
- +=: lists the total gradient and the new parameter value for the traced feature.
  INFO [SRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) += 0.02653948565577576 = 1.0326318886240011
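Each trace line carries arithmetic you can check by hand:

- logP and logN print a ratio followed by its value.
- += prints a step followed by the new parameter value, which should equal the preceding was value plus the step.

The hypothetical Java helper below parses the expression at the end of a trace line and recomputes it. Its regular expressions are assumptions based only on the output syntax shown above, not a documented grammar.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical checker for the arithmetic printed in trace lines.
final class TraceCheck {
    private static final String NUM = "(-?[0-9.]+(?:E-?[0-9]+)?)";
    // logP: "a / b = c"; logN: "m * a / b = c"
    private static final Pattern RATIO = Pattern.compile(
            "\\b(logP|logN) (?:" + NUM + " \\* )?" + NUM + " / " + NUM + " = " + NUM + "\\s*$");
    private static final Pattern WAS = Pattern.compile("\\bwas " + NUM + "\\s*$");
    private static final Pattern STEP = Pattern.compile("\\+= " + NUM + " = " + NUM + "\\s*$");

    /** True if a logP/logN line's printed result matches its operands (tolerance scaled for values above 1). */
    static boolean ratioConsistent(String line, double tol) {
        Matcher m = RATIO.matcher(line);
        if (!m.find()) throw new IllegalArgumentException("not a logP/logN trace line: " + line);
        double mult = m.group(2) == null ? 1.0 : Double.parseDouble(m.group(2));
        double expected = mult * Double.parseDouble(m.group(3)) / Double.parseDouble(m.group(4));
        double printed = Double.parseDouble(m.group(5));
        return Math.abs(expected - printed) <= tol * Math.max(1.0, Math.abs(expected));
    }

    /** True if a "was w" line and the following "+= d = v" line satisfy w + d = v. */
    static boolean updateConsistent(String wasLine, String stepLine, double tol) {
        Matcher w = WAS.matcher(wasLine);
        Matcher s = STEP.matcher(stepLine);
        if (!w.find() || !s.find()) throw new IllegalArgumentException("expected a 'was' line and a '+=' line");
        double expected = Double.parseDouble(w.group(1)) + Double.parseDouble(s.group(1));
        return Math.abs(expected - Double.parseDouble(s.group(2))) <= tol;
    }
}
```

Applied to the sample lines, 1.0060924029682254 + 0.02653948565577576 agrees with the printed 1.0326318886240011. The same holds later in the log, where 0.42857142857142855 + 0.02328494577542239 gives 0.45185637434685094.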
Sample output:
INFO [Trainer] main
edu.cmu.ml.proppr.util.ModuleConfiguration
queries file: train_no_cornell.data.grounded
params file: cornell.params
Walker: edu.cmu.ml.proppr.learn.L2PosNegLossTrainedSRW
Weighting Scheme: edu.cmu.ml.proppr.learn.tools.ReLUWeightingScheme
Alpha: 0.1
Epsilon: 1.0E-4
Max depth: 5
Strategy: exception
INFO [Trainer] main Training model parameters on train_no_cornell.data.grounded...
INFO [Trainer] main epoch 1 ...
INFO [PosNegLossTrainedSRW] transformer-1 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-2 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-3 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-4 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-6 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-5 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-7 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-8 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-9 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-11 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-10 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-12 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-14 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-16 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-15 trace db(LightweightGraphPlugin,webkb.graph) reg 0.002012184805936451
INFO [PosNegLossTrainedSRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) logP -7.394501070610791E-4 / 0.022347119637303438 = -0.0330892803664386
INFO [PosNegLossTrainedSRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) logN 1.0 * 7.389713595330534E-4 / 0.9776673487199024 = 7.558515281303167E-4
INFO [PosNegLossTrainedSRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) logN 1.0 * 7.395496387465339E-4 / 0.9776498723889857 = 7.56456538923664E-4
INFO [PosNegLossTrainedSRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) logN 1.0 * 7.391368182232194E-4 / 0.9776623483485334 = 7.560246331177311E-4
INFO [PosNegLossTrainedSRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) logN 1.0 * 7.393253306708433E-4 / 0.9776566512633886 = 7.562218593976129E-4
INFO [PosNegLossTrainedSRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) logN 1.0 * 7.395527564385326E-4 / 0.9776497781683812 = 7.564598007929574E-4
INFO [PosNegLossTrainedSRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) logN 1.0 * 7.396824994997896E-4 / 0.9776458571685945 = 7.565955443641099E-4
INFO [SRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) was 1.0060924029682254
INFO [SRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) += 0.02653948565577576 = 1.0326318886240011
INFO [PosNegLossTrainedSRW] transformer-13 trace db(LightweightGraphPlugin,webkb.graph) reg 8.571428571428571E-4
INFO [PosNegLossTrainedSRW] transformer-6 trace db(LightweightGraphPlugin,webkb.graph) logP -6.585123213126203E-4 / 0.022450301786658716 = -0.029332003087100898
INFO [PosNegLossTrainedSRW] transformer-6 trace db(LightweightGraphPlugin,webkb.graph) logN 1.0 * 6.582897112534067E-4 / 0.9775573496978436 = 6.734026514729593E-4
INFO [PosNegLossTrainedSRW] transformer-6 trace db(LightweightGraphPlugin,webkb.graph) logN 1.0 * 6.566761957809852E-4 / 0.9775676082696231 = 6.71745043745217E-4
INFO [PosNegLossTrainedSRW] transformer-6 trace db(LightweightGraphPlugin,webkb.graph) logN 1.0 * 6.580624930792297E-4 / 0.9775542233663602 = 6.731723697260384E-4
INFO [PosNegLossTrainedSRW] transformer-6 trace db(LightweightGraphPlugin,webkb.graph) logN 1.0 * 6.580843757838968E-4 / 0.9775641812563434 = 6.731878974310839E-4
INFO [PosNegLossTrainedSRW] transformer-6 trace db(LightweightGraphPlugin,webkb.graph) logN 1.0 * 6.565016397595402E-4 / 0.9775735812671842 = 6.715623788733601E-4
INFO [PosNegLossTrainedSRW] transformer-6 trace db(LightweightGraphPlugin,webkb.graph) logN 1.0 * 6.567307835602201E-4 / 0.9775657452003786 = 6.718021644934022E-4
INFO [SRW] transformer-6 trace db(LightweightGraphPlugin,webkb.graph) was 0.42857142857142855
INFO [SRW] transformer-6 trace db(LightweightGraphPlugin,webkb.graph) += 0.02328494577542239 = 0.45185637434685094