From 7a26d4dc8f1a266a000f393a929532db0e672a3c Mon Sep 17 00:00:00 2001 From: bab2min Date: Tue, 17 Dec 2024 21:42:49 +0900 Subject: [PATCH] Enhance HSDataset with dropout probability on history and add extractPrefixes method --- include/kiwi/Dataset.h | 5 ++- include/kiwi/Kiwi.h | 8 ++++ src/Dataset.cpp | 47 +++++++++++++++++++++-- src/KiwiBuilder.cpp | 80 ++++++++++++++++++++++++++++++++++++++-- src/RaggedVector.hpp | 84 +++++++++++++++++++++++++++++++++--------- 5 files changed, 200 insertions(+), 24 deletions(-) diff --git a/include/kiwi/Dataset.h b/include/kiwi/Dataset.h index 32be4879..e45738d5 100644 --- a/include/kiwi/Dataset.h +++ b/include/kiwi/Dataset.h @@ -49,6 +49,7 @@ namespace kiwi std::unique_ptr workers; std::shared_ptr dummyBuilder; std::discrete_distribution<> dropout; + std::bernoulli_distribution dropoutOnHistory; std::mt19937_64 rng; Vector locals; Vector shuffledIdx; @@ -71,7 +72,7 @@ namespace kiwi size_t _next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut); public: - HSDataset(size_t _batchSize = 0, size_t _causalContextSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0); + HSDataset(size_t _batchSize = 0, size_t _causalContextSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0, double _dropoutProbOnHistory = 0); ~HSDataset(); HSDataset(const HSDataset&) = delete; HSDataset(HSDataset&&) /*noexcept*/; @@ -101,5 +102,7 @@ namespace kiwi Range::const_iterator> getSent(size_t idx) const; std::vector getAugmentedSent(size_t idx); + + std::vector, size_t>> extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers = 1) const; }; } diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h index b469bf82..1c2ee7d2 100644 --- a/include/kiwi/Kiwi.h +++ b/include/kiwi/Kiwi.h @@ -799,11 +799,19 @@ namespace kiwi return build(getDefaultTypoSet(typos), typoCostThreshold); } + void convertHSData( + const std::vector& inputPathes, + const std::string& outputPath, + const std::string& morphemeDefPath = {}, + size_t morphemeDefMinCnt = 0 + ) const; + using TokenFilter = std::function; HSDataset makeHSDataset(const std::vector& inputPathes, size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers, double dropoutProb = 0, + double dropoutProbOnHistory = 0, const TokenFilter& tokenFilter = {}, const TokenFilter& windowFilter = {}, double splitRatio = 0, diff --git a/src/Dataset.cpp b/src/Dataset.cpp index f3e47edd..18d2b03d 100644 --- a/src/Dataset.cpp +++ b/src/Dataset.cpp @@ -1,11 +1,14 @@ #include +#include #include "RaggedVector.hpp" using namespace kiwi; -HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, size_t _windowSize, size_t _workers, double _dropoutProb) +HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, size_t _windowSize, size_t _workers, + double _dropoutProb, double _dropoutProbOnHistory) : workers{ _workers ? make_unique(_workers) : nullptr }, dropout{ {1 - _dropoutProb * 3, _dropoutProb, _dropoutProb, _dropoutProb} }, + dropoutOnHistory{ _dropoutProbOnHistory }, locals( _workers ? workers->size() : 1), batchSize{ _batchSize }, causalContextSize{ _causalContextSize }, @@ -149,8 +152,19 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, { for (size_t j = 0; j < causalContextSize; ++j) { - local.inData.emplace_back(i + j < causalContextSize ? - nonVocab : tokenToVocab[tokens[i + j - causalContextSize]]); + if (i + j < causalContextSize) + { + local.inData.emplace_back(nonVocab); + } + else + { + auto t = tokens[i + j - causalContextSize]; + if (dropoutOnHistory.p() > 0 && dropoutOnHistory(local.rng)) + { + t = getDefaultMorphemeId((*morphemes)[t].tag); + } + local.inData.emplace_back(tokenToVocab[t]); + } } } if (windowSize) @@ -347,3 +361,30 @@ std::vector HSDataset::getAugmentedSent(size_t idx) ret.emplace_back(*sent.rbegin()); return ret; } + +std::vector, size_t>> kiwi::HSDataset::extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers) const +{ + using Pair = std::pair, size_t>; + std::vector ret; + PrefixCounter counter{ maxLength, minCnt, numWorkers }; + for (auto sent : sents.get()) + { + counter.addArray(&*sent.begin(), &*sent.end()); + } + auto trie = counter.count(); + trie.traverse([&](size_t cnt, const std::vector& prefix) + { + if (cnt < minCnt) return; + if (std::find_if(prefix.begin(), prefix.end(), [](uint32_t t) { return t < 2; }) != prefix.end()) + { + return; + } + ret.emplace_back(prefix, cnt); + }); + + std::sort(ret.begin(), ret.end(), [](const Pair& a, const Pair& b) + { + return a.second > b.second; + }); + return ret; +} diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index 0c370065..06b47b86 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -598,6 +598,7 @@ void KiwiBuilder::_addCorpusTo( auto fields = split(wstr, u'\t'); if (fields.size() < 2) continue; + size_t mergedIndex = -1; for (size_t i = 1; i < fields.size(); i += 2) { auto f = normalizeHangul(fields[i]); @@ -618,6 +619,24 @@ void KiwiBuilder::_addCorpusTo( alreadyPrintError = true; } + if (t == POSTag::z_siot || i == mergedIndex) + { + continue; + } + + if (i + 6 < fields.size() && toPOSTag(fields[i + 3]) == POSTag::z_siot) + { + auto nf = f; + nf += normalizeHangul(fields[i + 2]); + nf += normalizeHangul(fields[i + 4]); + if (morphMap.count(make_tuple(nf, 0, POSTag::nng))) + { + f = nf; + t = POSTag::nng; + mergedIndex = i + 4; + } + } + if (f[0] == u'아' && fields[i + 1][0] == 'E') { f[0] = u'어'; @@ -2247,9 +2266,47 @@ vector KiwiBuilder::extractAddWords(const U16MultipleReader& reader, s return words; } +void KiwiBuilder::convertHSData( + const vector& inputPathes, + const string& outputPath, + const string& morphemeDefPath, + size_t morphemeDefMinCnt +) const +{ + unique_ptr dummyBuilder; + const KiwiBuilder* srcBuilder = this; + MorphemeMap realMorph; + if (morphemeDefPath.empty()) + { + realMorph = restoreMorphemeMap(); + } + else + { + dummyBuilder = make_unique(); + dummyBuilder->initMorphemes(); + ifstream ifs; + realMorph = dummyBuilder->loadMorphemesFromTxt(openFile(ifs, morphemeDefPath), [&](POSTag tag, float cnt) + { + return cnt >= morphemeDefMinCnt; + }); + srcBuilder = dummyBuilder.get(); + } + + RaggedVector sents; + for (auto& path : inputPathes) + { + ifstream ifs; + srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph); + } + + ofstream ofs; + sents.write_to_memory(openFile(ofs, outputPath, ios_base::binary)); +} + HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers, double dropoutProb, + double dropoutProbOnHistory, const TokenFilter& tokenFilter, const TokenFilter& windowFilter, double splitRatio, @@ -2259,7 +2316,7 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, HSDataset* splitDataset ) const { - HSDataset dataset{ batchSize, causalContextSize, windowSize, numWorkers, dropoutProb }; + HSDataset dataset{ batchSize, causalContextSize, windowSize, numWorkers, dropoutProb, dropoutProbOnHistory }; auto& sents = dataset.sents.get(); const KiwiBuilder* srcBuilder = this; MorphemeMap realMorph; @@ -2301,8 +2358,25 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, for (auto& path : inputPathes) { - ifstream ifs; - srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph, splitRatio, splitDataset ? &splitDataset->sents.get() : nullptr); + try + { + ifstream ifs; + auto cvtSents = RaggedVector::from_memory(openFile(ifs, path, ios_base::binary)); + if (splitRatio > 0) + { + throw invalid_argument("splitDataset cannot be used with binary input"); + } + for (auto s : cvtSents) + { + sents.emplace_back(); + sents.insert_data(s.begin(), s.end()); + } + } + catch (const runtime_error&) + { + ifstream ifs; + srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph, splitRatio, splitDataset ? &splitDataset->sents.get() : nullptr); + } } size_t tokenSize = sents.raw().empty() ? 0 : *max_element(sents.raw().begin(), sents.raw().end()) + 1; diff --git a/src/RaggedVector.hpp b/src/RaggedVector.hpp index abee3880..bfac8b52 100644 --- a/src/RaggedVector.hpp +++ b/src/RaggedVector.hpp @@ -94,8 +94,8 @@ namespace kiwi auto operator[](size_t idx) -> Range { - size_t b = idx < ptrs.size() ? ptrs[idx] : data.size(); - size_t e = idx + 1 < ptrs.size() ? ptrs[idx + 1] : data.size(); + const size_t b = idx < ptrs.size() ? ptrs[idx] : data.size(); + const size_t e = idx + 1 < ptrs.size() ? ptrs[idx + 1] : data.size(); return { data.begin() + b, data.begin() + e }; } @@ -142,31 +142,81 @@ namespace kiwi return { *this, size() }; } - utils::MemoryObject toMemory() const + utils::MemoryObject to_memory() const { - utils::MemoryOwner ret{ sizeof(size_t) * 2 + sizeof(ValueTy) * data.size() + sizeof(size_t) * ptrs.size() }; + utils::MemoryOwner ret{ 4 + sizeof(uint64_t) * 2 + sizeof(ValueTy) * data.size() + sizeof(uint64_t) * ptrs.size() }; utils::omstream ostr{ (char*)ret.get(), (ptrdiff_t)ret.size()}; - size_t s; + write_to_memory(ostr); + return ret; + } + + void write_to_memory(std::ostream& ostr) const + { + if (!ostr.write("KIRV", 4)) + { + throw std::runtime_error("Failed to write RaggedVector memory object"); + } + uint64_t s; s = data.size(); - ostr.write((const char*)&s, sizeof(size_t)); + if (!ostr.write((const char*)&s, sizeof(uint64_t))) + { + throw std::runtime_error("Failed to write RaggedVector memory object"); + } + s = ptrs.size(); - ostr.write((const char*)&s, sizeof(size_t)); - ostr.write((const char*)data.data(), sizeof(ValueTy) * data.size()); - ostr.write((const char*)ptrs.data(), sizeof(size_t) * ptrs.size()); - return ret; + if (!ostr.write((const char*)&s, sizeof(uint64_t))) + { + throw std::runtime_error("Failed to write RaggedVector memory object"); + } + + if (!ostr.write((const char*)data.data(), sizeof(ValueTy) * data.size())) + { + throw std::runtime_error("Failed to write RaggedVector memory object"); + } + + if (!ostr.write((const char*)ptrs.data(), sizeof(uint64_t) * ptrs.size())) + { + throw std::runtime_error("Failed to write RaggedVector memory object"); + } } - static RaggedVector fromMemory(std::istream& istr) + static RaggedVector from_memory(std::istream& istr) { RaggedVector ret; - size_t s; - istr.read((char*)&s, sizeof(size_t)); + char buf[4]; + if (!istr.read(buf, 4)) + { + throw std::runtime_error("Invalid RaggedVector memory object"); + } + + if (memcmp(buf, "KIRV", 4) != 0) + { + throw std::runtime_error("Invalid RaggedVector memory object"); + } + + uint64_t s; + if (!istr.read((char*)&s, sizeof(uint64_t))) + { + throw std::runtime_error("Invalid RaggedVector memory object"); + } ret.data.resize(s); - istr.read((char*)&s, sizeof(size_t)); + + if (!istr.read((char*)&s, sizeof(uint64_t))) + { + throw std::runtime_error("Invalid RaggedVector memory object"); + } ret.ptrs.resize(s); - istr.read((char*)ret.data.data(), sizeof(ValueTy) * ret.data.size()); - istr.read((char*)ret.ptrs.data(), sizeof(size_t) * ret.ptrs.size()); + + if (!istr.read((char*)ret.data.data(), sizeof(ValueTy) * ret.data.size())) + { + throw std::runtime_error("Invalid RaggedVector memory object"); + } + + if (!istr.read((char*)ret.ptrs.data(), sizeof(uint64_t) * ret.ptrs.size())) + { + throw std::runtime_error("Invalid RaggedVector memory object"); + } return ret; } }; -} \ No newline at end of file +}