@@ -30,7 +30,7 @@ void toPVariable(PVariable x1, float *X){



-WordEmbed *load_data(string filename, int vocab_size, bool addEOS, bool addSOS){
+WordEmbed *load_data(string filename, int vocab_size, bool tokenize, bool addEOS){


    std::ifstream reading_file(filename, std::ios::in);
@@ -48,7 +48,7 @@ WordEmbed *load_data(string filename, int vocab_size, bool addEOS, bool addSOS){

    WordEmbed *wd = new WordEmbed(vocab_size);

-   wd->addSentences(sequences, false, addEOS, addSOS);
+   wd->addSentences(sequences, tokenize, addEOS);

    return wd;
}
@@ -140,7 +140,6 @@ PVariable attention_hidden_state(PVariable h, PVariable a){

    return model.G("attention_linear_tanh")->forward(attention_plus);
}
-///////////////////////////


int get_max_vocab_size(vector<vector<int>> &seqs_ids, int batch_size, int k){
@@ -155,8 +154,6 @@ int get_max_vocab_size(vector<vector<int>> &seqs_ids, int batch_size, int k){
vector<PVariable> encoder(vector<vector<int>> &seqs_ids_ja, WordEmbed *wd_ja, int batch_size, int vocab_size, int k){

    int max_vocab_size_ja = get_max_vocab_size(seqs_ids_ja, batch_size, k);
-   //cout << "max_vocab_size_ja:" << max_vocab_size_ja << endl;
-

    vector<PVariable> src_hidden_states;

@@ -173,7 +170,6 @@ vector<PVariable> encoder(vector<vector<int>> &seqs_ids_ja, WordEmbed *wd_ja, in
        reverse(word_ids.begin(), word_ids.end());

        bool ignore = false;
-       //if (word_ids[j] == wd_ja->PAD_ID) ignore = true;
        wd_ja->toOneHot(vocab_size, data_ja, word_ids[j], batch_idx, ignore);
        batch_idx++;
    }
@@ -308,7 +304,6 @@ PVariable forward_one_step(vector<vector<int>> &seqs_ids_ja, vector<vector<int>>
        wd_en->padding(word_ids, max_vocab_size_en);

        bool ignore = false;
-       //if (word_ids[j] == wd_en->PAD_ID) ignore = true;
        wd_en->toOneHot(vocab_size, data_en, word_ids[j], batch_idx, ignore);
        batch_idx++;
    }
@@ -342,11 +337,10 @@ int main(){
    float clip_grad_threshold = 0;
    float learning_rate = 0.001; //ADAM

-   int epoch = 100;
-
+   int epoch = 20;

-   WordEmbed *wd_ja = load_data("tanaka_corpus_j_10000.txt", vocab_size, false, false);
-   WordEmbed *wd_en = load_data("tanaka_corpus_e_10000.txt", vocab_size, true, false);
+   WordEmbed *wd_ja = load_data("tanaka_corpus_j_10000.txt.train", vocab_size, true, false);
+   WordEmbed *wd_en = load_data("tanaka_corpus_e_10000.txt.train", vocab_size, true, true);

    vector<vector<int>> seqs_ids_ja = wd_ja->getSequencesIds();
    vector<vector<int>> seqs_ids_en = wd_en->getSequencesIds();
@@ -471,7 +465,7 @@ int main(){
        ((FullLSTM2 *) model.G("lstm_ja"))->reset_state();
        ((FullLSTM2 *) model.G("lstm_en"))->reset_state();
    }
-
+
    delete wd_ja;
    delete wd_en;

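For reference, a minimal sketch of the call pattern introduced above, read off the new signature rather than confirmed elsewhere in the diff: tokenize presumably splits each corpus line into words before indexing, and addEOS appends an end-of-sequence token to every sentence, which only the decoder-side (English) corpus needs for seq2seq training.

    // Hypothetical usage of the new signature:
    //   load_data(filename, vocab_size, tokenize, addEOS)
    WordEmbed *wd_src = load_data("tanaka_corpus_j_10000.txt.train", vocab_size, true, false); // source: tokenize, no EOS
    WordEmbed *wd_trg = load_data("tanaka_corpus_e_10000.txt.train", vocab_size, true, true);  // target: tokenize + EOS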