forked from erelsgl-nlp/languagemodel
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCrossLanguageModel.js
214 lines (187 loc) · 8.34 KB
/
CrossLanguageModel.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
var LanguageModel = require('./LanguageModel');
var logSumExp = require('./logSumExp');
var extend = require('util')._extend;
/**
 * A statistical model over two different languages - an input language and an output language.
 *
 * Based on:
 * Leuski Anton, Traum David. "A Statistical Approach for Text Processing in Virtual Humans."
 * Tech. rep., University of Southern California, Institute for Creative Technologies, 2008.
 * http://www.citeulike.org/user/erelsegal-halevi/article/12540655
 *
 * @author Erel Segal-Halevi
 * @since 2013-08
 *
 * @param opts {Object} optional; may contain the following options:
 * * smoothingCoefficient - the lambda-factor for smoothing the unigram probabilities (default 0.9).
 *   The full opts object is also forwarded to the two underlying LanguageModels.
 */
var CrossLanguageModel = function(opts) {
	opts = opts || {}; // allow construction with no options at all
	this.smoothingCoefficient = opts.smoothingCoefficient || 0.9;
	this.inputLanguageModel = new LanguageModel(opts);
	this.outputLanguageModel = new LanguageModel(opts);
};
CrossLanguageModel.prototype = {
	/**
	 * Online (incremental) training is not supported by this model - use trainBatch instead.
	 *
	 * NOTE(review): the original file declared trainOnline TWICE in this object literal
	 * (once with a copy-pasted classifier doc about "sample/classes"); only the second
	 * declaration survived at runtime, so a single declaration is kept here.
	 *
	 * @param input  a hash {feature1:count1, ...} from the INPUT language.
	 * @param output a hash {feature1:count1, ...} from the OUTPUT language.
	 * @throws Error always.
	 */
	trainOnline: function(input, output) {
		throw new Error("CrossLanguageModel currently does not support online training");
	},

	/**
	 * Train the model with all the given documents.
	 *
	 * @param dataset an array with objects of the format:
	 *   {input: {feature1:count1, feature2:count2,...}, output: {feature1:count1, feature2:count2,...}}
	 */
	trainBatch : function(dataset) {
		this.inputLanguageModel.trainBatch(dataset.map(function(datum) { return datum.input; }));
		this.outputLanguageModel.trainBatch(dataset.map(function(datum) { return datum.output; }));
		// Cache the set of output-language features; "_total" is the LanguageModel's
		// internal grand-total counter, not a real feature, so it is stripped.
		this.outputFeatures = extend({}, this.outputLanguageModel.getAllWordCounts());
		delete this.outputFeatures["_total"];
	},

	getAllWordCounts: function() {
		// FIXME(review): this.mapWordToTotalCount is never assigned on this class
		// (it looks copy-pasted from LanguageModel), so this always returns undefined.
		// Kept as-is to preserve the existing (broken) behavior for any callers.
		return this.mapWordToTotalCount;
	},

	/**
	 * Calculate the Kullback-Leibler divergence between the language models of the given samples.
	 * This can be used as an approximation of the (inverse) semantic similarity between them.
	 *
	 * @param inputSentenceCounts  (hash) represents a sentence from the INPUT domain.
	 * @param outputSentenceCounts (hash) represents a sentence from the OUTPUT domain.
	 * @return {Number} the divergence (>= 0 in theory; smaller means more similar).
	 *
	 * @note divergence is not symmetric - divergence(a,b) != divergence(b,a).
	 */
	divergence: function(inputSentenceCounts, outputSentenceCounts) { // (6) D(P(W)||P(F)) = ...
		if (outputSentenceCounts !== Object(outputSentenceCounts))
			throw new Error("expected outputSentenceCounts to be an object, but found "+JSON.stringify(outputSentenceCounts));

		var elements = []; // summands of the KL sum, one per output feature
		for (var feature in this.outputFeatures) {
			var logFeatureGivenInput = this.logProbFeatureGivenSentence(feature, inputSentenceCounts);
			var probFeatureGivenInput = Math.exp(logFeatureGivenInput);
			var logFeatureGivenOutput = this.outputLanguageModel.logProbWordGivenSentence(feature, outputSentenceCounts);
			// p * (log p - log q), the per-feature KL contribution:
			elements.push(probFeatureGivenInput * (logFeatureGivenInput - logFeatureGivenOutput));
		}
		return elements.reduce(function(memo, num) { return memo + num; }, 0);
	},

	/**
	 * Calculate the similarity scores between the given input sentence and all output
	 * sentences in the corpus, sorted from high (most similar) to low (least similar).
	 * Note: similarity = - divergence.
	 *
	 * @param inputSentenceCounts (hash) represents a sentence from the INPUT domain.
	 * @return an array of {output: {word:count,...}, similarity: Number}, sorted descending.
	 */
	similarities: function(inputSentenceCounts) {
		var sims = [];
		for (var i in this.outputLanguageModel.dataset) {
			// Copy so that deleting "_total" does not mutate the stored corpus sentence.
			var output = extend({}, this.outputLanguageModel.dataset[i]);
			delete output['_total'];
			sims.push({
				output: output,
				similarity: -this.divergence(inputSentenceCounts, output)
			});
		}
		sims.sort(function(a, b) {
			return b.similarity - a.similarity;
		});
		return sims;
	},

	/**
	 * @param feature a single feature (word) from the OUTPUT domain.
	 * @param givenSentenceCounts a hash that represents a sentence from the INPUT domain.
	 * @return {Number} log P(feature | sentence).
	 * @throws Error if givenSentenceCounts is falsy, or if an intermediate log-prob is NaN/infinite.
	 */
	logProbFeatureGivenSentence: function(feature, givenSentenceCounts) { // (5) P(f|W) = ...
		if (!givenSentenceCounts)
			throw new Error("no givenSentenceCounts");
		var logSentenceAndFeature = this.logProbSentenceAndFeatureGivenDataset(feature, givenSentenceCounts);
		if (isNaN(logSentenceAndFeature) || !isFinite(logSentenceAndFeature))
			throw new Error("logSentenceAndFeature is "+logSentenceAndFeature);
		var logSentence = this.inputLanguageModel.logProbSentenceGivenDataset(givenSentenceCounts);
		if (isNaN(logSentence) || !isFinite(logSentence))
			throw new Error("logSentence is "+logSentence);
		// log P(f|W) = log P(f,W) - log P(W)
		return logSentenceAndFeature - logSentence;
	},

	/**
	 * @param feature a single feature (word) from the OUTPUT domain.
	 * @param sentenceCounts a hash that represents a sentence from the INPUT domain.
	 * @return {Number} the log of the joint probability of the output feature and the input sentence.
	 * @throws Error if sentenceCounts is falsy.
	 */
	logProbSentenceAndFeatureGivenDataset: function(feature, sentenceCounts) { // (2') log P(f,w1...wn) = ...
		if (!sentenceCounts)
			throw new Error("no sentenceCounts");
		// One log-product per training pair: P(sentence|input_i) * P(feature|output_i).
		var logProducts = [];
		for (var i = 0; i < this.inputLanguageModel.dataset.length; ++i) {
			logProducts.push(
				this.inputLanguageModel .logProbSentenceGivenSentence(sentenceCounts, this.inputLanguageModel.dataset[i]) +
				this.outputLanguageModel.logProbWordGivenSentence(feature, this.outputLanguageModel.dataset[i])
			);
		}
		// Sum the products in log-space, then average over the corpus size.
		var logSentenceLikelihood = logSumExp(logProducts);
		return logSentenceLikelihood - Math.log(this.inputLanguageModel.dataset.length); // The last element is not needed in practice (see eq. (5))
	},

	toJSON: function() {
		return {
			inputLanguageModel: this.inputLanguageModel.toJSON(),
			outputLanguageModel: this.outputLanguageModel.toJSON(),
		};
	},

	fromJSON: function(json) {
		this.inputLanguageModel.fromJSON(json.inputLanguageModel);
		this.outputLanguageModel.fromJSON(json.outputLanguageModel);
	}
};
module.exports = CrossLanguageModel;

// Demo: run this file directly (node CrossLanguageModel.js) to print similarity
// scores between a few input sentences and the tiny training corpus below.
if (process.argv[1] === __filename) {
	console.log("CrossLanguageModel demo start");

	var model = new CrossLanguageModel({
		// BUGFIX: was "smoothingFactor", which the constructor ignores - it reads
		// opts.smoothingCoefficient (the intent was masked by the 0.9 default).
		smoothingCoefficient: 0.9,
	});

	var wordcounts = require('./wordcounts');
	model.trainBatch([
		{input: wordcounts("I want aa"), output: wordcounts("a")},
		{input: wordcounts("I want bb"), output: wordcounts("b")},
		{input: wordcounts("I want cc"), output: wordcounts("c")},
	]);

	// Print the similarity of the given input sentence to every output sentence in the corpus.
	var show = function(sentence) {
		console.log(sentence+": ");
		console.dir(model.similarities(wordcounts(sentence)));
	};

	show("I want");
	show("I want nothing");
	show("I want aa");
	show("I want bb");
	show("I want aa and bb");
	show("I want aa , bb and cc");
	show("I want aa bb cc");

	console.log("CrossLanguageModel demo end");
}