
Commit ebe94df (1 parent: 7ad7e84)

3 files changed: +22 −23 lines


test.cpp.seq2seq (+6 −12)

@@ -30,7 +30,7 @@ void toPVariable(PVariable x1, float *X){
 
 
 
-WordEmbed *load_data(string filename, int vocab_size, bool addEOS, bool addSOS){
+WordEmbed *load_data(string filename, int vocab_size, bool tokenize, bool addEOS){
 
 
     std::ifstream reading_file(filename, std::ios::in);
@@ -48,7 +48,7 @@ WordEmbed *load_data(string filename, int vocab_size, bool addEOS, bool addSOS){
 
     WordEmbed *wd = new WordEmbed(vocab_size);
 
-    wd->addSentences(sequences, false, addEOS, addSOS);
+    wd->addSentences(sequences, tokenize, addEOS);
 
     return wd;
 }
@@ -140,7 +140,6 @@ PVariable attention_hidden_state(PVariable h, PVariable a){
 
     return model.G("attention_linear_tanh")->forward(attention_plus);
 }
-///////////////////////////
 
 
 int get_max_vocab_size(vector<vector<int>> &seqs_ids, int batch_size, int k){
@@ -155,8 +154,6 @@ int get_max_vocab_size(vector<vector<int>> &seqs_ids, int batch_size, int k){
 vector<PVariable> encoder(vector<vector<int>> &seqs_ids_ja, WordEmbed *wd_ja, int batch_size, int vocab_size, int k){
 
     int max_vocab_size_ja = get_max_vocab_size(seqs_ids_ja, batch_size, k);
-    //cout << "max_vocab_size_ja:" << max_vocab_size_ja << endl;
-
 
     vector<PVariable> src_hidden_states;
 
@@ -173,7 +170,6 @@ vector<PVariable> encoder(vector<vector<int>> &seqs_ids_ja, WordEmbed *wd_ja, in
             reverse(word_ids.begin(), word_ids.end());
 
             bool ignore = false;
-            //if (word_ids[j] == wd_ja->PAD_ID) ignore = true;
             wd_ja->toOneHot(vocab_size, data_ja, word_ids[j], batch_idx, ignore);
             batch_idx++;
         }
@@ -308,7 +304,6 @@ PVariable forward_one_step(vector<vector<int>> &seqs_ids_ja, vector<vector<int>>
             wd_en->padding(word_ids, max_vocab_size_en);
 
             bool ignore = false;
-            //if (word_ids[j] == wd_en->PAD_ID) ignore = true;
             wd_en->toOneHot(vocab_size, data_en, word_ids[j], batch_idx, ignore);
             batch_idx++;
         }
@@ -342,11 +337,10 @@ int main(){
     float clip_grad_threshold = 0;
     float learning_rate = 0.001; //ADAM
 
-    int epoch = 100;
-
+    int epoch = 20;
 
-    WordEmbed *wd_ja = load_data("tanaka_corpus_j_10000.txt", vocab_size, false, false);
-    WordEmbed *wd_en = load_data("tanaka_corpus_e_10000.txt", vocab_size, true, false);
+    WordEmbed *wd_ja = load_data("tanaka_corpus_j_10000.txt.train", vocab_size, true, false);
+    WordEmbed *wd_en = load_data("tanaka_corpus_e_10000.txt.train", vocab_size, true, true);
 
     vector<vector<int>> seqs_ids_ja = wd_ja->getSequencesIds();
     vector<vector<int>> seqs_ids_en = wd_en->getSequencesIds();
@@ -471,7 +465,7 @@ int main(){
         ((FullLSTM2 *) model.G("lstm_ja"))->reset_state();
         ((FullLSTM2 *) model.G("lstm_en"))->reset_state();
     }
-
+
     delete wd_ja;
     delete wd_en;
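For reference, a minimal sketch of call sites under the reworked load_data signature, as the commit uses it in main(): the third argument now selects MeCab tokenization and the fourth appends an <eos> marker per sentence. The filenames and vocab_size value below are hypothetical stand-ins, not taken from this repository:

    // Hypothetical call sites illustrating the new parameter order.
    // Source side: tokenize, no <eos>; target side: tokenize and append
    // <eos> so the decoder can learn an end-of-sequence symbol.
    int vocab_size = 10000;  // assumed value for illustration
    WordEmbed *wd_src = load_data("corpus_src.txt.train", vocab_size,
                                  /*tokenize=*/true, /*addEOS=*/false);
    WordEmbed *wd_tgt = load_data("corpus_tgt.txt.train", vocab_size,
                                  /*tokenize=*/true, /*addEOS=*/true);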

tokenizer.h (+4 −3)

@@ -33,7 +33,8 @@ class Tokenizer {
 public:
     Tokenizer(){
         //tagger = MeCab::createTagger("-Owakati");
-        tagger = MeCab::createTagger("-xunknown");
+        //tagger = MeCab::createTagger("-xunknown -d /usr/local/lib/mecab/dic/mecab-ipadic-neologd");
+        tagger = MeCab::createTagger("-d /usr/local/lib/mecab/dic/mecab-ipadic-neologd");
     }
     ~Tokenizer(){
         delete tagger;
@@ -50,12 +51,12 @@ class Tokenizer {
 
         for (; node; node = node->next) {
             string feature(node->feature);
-            if (feature.find("名詞")==0 || feature.find("未知語")==0){
+            //if (feature.find("名詞")==0 || feature.find("未知語")==0){
                 strcpy(buf,node->surface);
                 buf[node->length]='\0';
                 string surface(buf);
                 result.push_back(surface);
-            }
+            //}
         }
         return result;
     }

word_embed.h (+12 −8)

@@ -89,9 +89,9 @@ class WordEmbed {
     }
 
 
-    void addSentences(vector<string> seqs, bool tokenize, bool addEOS, bool addSOS){
+    void addSentences(vector<string> seqs, bool tokenize, bool addEOS){
         for (auto s : seqs){
-            add(s, tokenize, addEOS, addSOS);
+            add(s, tokenize, addEOS);
         }
 
         vector<pair<string, int> > pairs(words_count.size());
@@ -129,23 +129,27 @@ class WordEmbed {
 
     }
 
-    void add(string sentence, bool tokenize, bool addEOS, bool addSOS){
+    void add(string sentence, bool tokenize, bool addEOS){
 
         if (sentence == "") return;
 
-        if (addSOS) sentence = "<sos> " + sentence;
-        if (addEOS) sentence += " <eos>";
 
-        vector<string> words;
+        vector<string> words, words_final;
         if (tokenize) words = token.parse(sentence);
         else words = split(sentence, ' ');
 
-        for (auto w : words) {
+        if (addEOS) words.push_back("<eos>");
+
+        for (auto w : words){
+            if (w != "") words_final.push_back(w);
+        }
+
+        for (auto w : words_final) {
             if (words_count.count(w) == 0) words_count[w] = 1;
             else words_count[w] += 1;
         }
 
-        sequences.push_back(words);
+        sequences.push_back(words_final);
     }
 
 
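To make the reworked add() flow concrete, here is a minimal standalone sketch (a plain whitespace split standing in for the MeCab path, and no WordEmbed state): tokenize first, then append <eos> as its own token rather than splicing it into the raw sentence string, then drop empty tokens before counting and storing. Everything below is illustrative; only the ordering mirrors the diff:

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    int main() {
        std::string sentence = "this is a pen";
        bool addEOS = true;

        // Stand-in for the tokenize/split step in add().
        std::vector<std::string> words;
        std::istringstream iss(sentence);
        for (std::string w; iss >> w; ) words.push_back(w);

        // <eos> is appended as a token after tokenization, so the
        // tokenizer can never split or mangle it.
        if (addEOS) words.push_back("<eos>");

        // Empty surfaces (e.g. MeCab BOS/EOS nodes) are filtered out.
        std::vector<std::string> words_final;
        for (const auto &w : words)
            if (!w.empty()) words_final.push_back(w);

        for (const auto &w : words_final) std::cout << w << " ";
        std::cout << "\n";  // prints: this is a pen <eos>
        return 0;
    }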
