Skip to content

Commit

Permalink
Enhance HSDataset with dropout probability on history and add extractPrefixes method
Browse files Browse the repository at this point in the history
bab2min committed Dec 17, 2024
1 parent d00519d commit 7a26d4d
Showing 5 changed files with 200 additions and 24 deletions.
5 changes: 4 additions & 1 deletion include/kiwi/Dataset.h
Original file line number Diff line number Diff line change
@@ -49,6 +49,7 @@ namespace kiwi
std::unique_ptr<utils::ThreadPool> workers;
std::shared_ptr<KiwiBuilder> dummyBuilder;
std::discrete_distribution<> dropout;
std::bernoulli_distribution dropoutOnHistory;
std::mt19937_64 rng;
Vector<ThreadLocal> locals;
Vector<size_t> shuffledIdx;
@@ -71,7 +72,7 @@ namespace kiwi
size_t _next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut);

public:
HSDataset(size_t _batchSize = 0, size_t _causalContextSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0);
HSDataset(size_t _batchSize = 0, size_t _causalContextSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0, double _dropoutProbOnHistory = 0);
~HSDataset();
HSDataset(const HSDataset&) = delete;
HSDataset(HSDataset&&) /*noexcept*/;
@@ -101,5 +102,7 @@ namespace kiwi

Range<Vector<uint32_t>::const_iterator> getSent(size_t idx) const;
std::vector<uint32_t> getAugmentedSent(size_t idx);

std::vector<std::pair<std::vector<uint32_t>, size_t>> extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers = 1) const;
};
}
8 changes: 8 additions & 0 deletions include/kiwi/Kiwi.h
Original file line number Diff line number Diff line change
@@ -799,11 +799,19 @@ namespace kiwi
return build(getDefaultTypoSet(typos), typoCostThreshold);
}

void convertHSData(
const std::vector<std::string>& inputPathes,
const std::string& outputPath,
const std::string& morphemeDefPath = {},
size_t morphemeDefMinCnt = 0
) const;

using TokenFilter = std::function<bool(const std::u16string&, POSTag)>;

HSDataset makeHSDataset(const std::vector<std::string>& inputPathes,
size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers,
double dropoutProb = 0,
double dropoutProbOnHistory = 0,
const TokenFilter& tokenFilter = {},
const TokenFilter& windowFilter = {},
double splitRatio = 0,
47 changes: 44 additions & 3 deletions src/Dataset.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#include <kiwi/Dataset.h>
#include <kiwi/SubstringExtractor.h>
#include "RaggedVector.hpp"

using namespace kiwi;

HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, size_t _windowSize, size_t _workers, double _dropoutProb)
HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, size_t _windowSize, size_t _workers,
double _dropoutProb, double _dropoutProbOnHistory)
: workers{ _workers ? make_unique<utils::ThreadPool>(_workers) : nullptr },
dropout{ {1 - _dropoutProb * 3, _dropoutProb, _dropoutProb, _dropoutProb} },
dropoutOnHistory{ _dropoutProbOnHistory },
locals( _workers ? workers->size() : 1),
batchSize{ _batchSize },
causalContextSize{ _causalContextSize },
@@ -149,8 +152,19 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode,
{
for (size_t j = 0; j < causalContextSize; ++j)
{
local.inData.emplace_back(i + j < causalContextSize ?
nonVocab : tokenToVocab[tokens[i + j - causalContextSize]]);
if (i + j < causalContextSize)
{
local.inData.emplace_back(nonVocab);
}
else
{
auto t = tokens[i + j - causalContextSize];
if (dropoutOnHistory.p() > 0 && dropoutOnHistory(local.rng))
{
t = getDefaultMorphemeId((*morphemes)[t].tag);
}
local.inData.emplace_back(tokenToVocab[t]);
}
}
}
if (windowSize)
@@ -347,3 +361,30 @@ std::vector<uint32_t> HSDataset::getAugmentedSent(size_t idx)
ret.emplace_back(*sent.rbegin());
return ret;
}

/**
 * @brief Collects token prefixes from the stored sentences together with their counts.
 *
 * Every sentence in `sents` is fed to a `PrefixCounter` built from
 * (`maxLength`, `minCnt`, `numWorkers`); the resulting trie is then traversed and
 * each qualifying prefix is copied into the result.
 *
 * @param minCnt     Prefixes whose count is below this value are dropped.
 * @param maxLength  Forwarded to `PrefixCounter` as its first constructor argument.
 * @param numWorkers Forwarded to `PrefixCounter` as its third constructor argument.
 * @return (prefix, count) pairs sorted by descending count. The relative order of
 *         entries with equal counts is unspecified (`std::sort` is not stable).
 */
std::vector<std::pair<std::vector<uint32_t>, size_t>> kiwi::HSDataset::extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers) const
{
	using Pair = std::pair<std::vector<uint32_t>, size_t>;
	std::vector<Pair> ret;
	PrefixCounter counter{ maxLength, minCnt, numWorkers };
	for (auto sent : sents.get())
	{
		// Build the [first, last) pointer pair without ever dereferencing a
		// past-the-end iterator: `&*sent.end()` (and `&*sent.begin()` on an empty
		// sentence) is undefined behavior and aborts under checked iterators.
		const auto len = sent.end() - sent.begin();
		auto first = len > 0 ? &*sent.begin() : nullptr;
		counter.addArray(first, first + len);
	}

	auto trie = counter.count();
	trie.traverse([&](size_t cnt, const std::vector<uint32_t>& prefix)
	{
		if (cnt < minCnt) return;

		// Skip prefixes containing a token id below 2.
		// NOTE(review): ids 0 and 1 are presumably reserved/special tokens - confirm.
		const bool hasLowId = std::any_of(prefix.begin(), prefix.end(), [](uint32_t t) { return t < 2; });
		if (hasLowId) return;

		ret.emplace_back(prefix, cnt);
	});

	// Most frequent prefixes first.
	std::sort(ret.begin(), ret.end(), [](const Pair& a, const Pair& b)
	{
		return a.second > b.second;
	});
	return ret;
}
80 changes: 77 additions & 3 deletions src/KiwiBuilder.cpp
Original file line number Diff line number Diff line change
@@ -598,6 +598,7 @@ void KiwiBuilder::_addCorpusTo(
auto fields = split(wstr, u'\t');
if (fields.size() < 2) continue;

size_t mergedIndex = -1;
for (size_t i = 1; i < fields.size(); i += 2)
{
auto f = normalizeHangul(fields[i]);
@@ -618,6 +619,24 @@ void KiwiBuilder::_addCorpusTo(
alreadyPrintError = true;
}

if (t == POSTag::z_siot || i == mergedIndex)
{
continue;
}

if (i + 6 < fields.size() && toPOSTag(fields[i + 3]) == POSTag::z_siot)
{
auto nf = f;
nf += normalizeHangul(fields[i + 2]);
nf += normalizeHangul(fields[i + 4]);
if (morphMap.count(make_tuple(nf, 0, POSTag::nng)))
{
f = nf;
t = POSTag::nng;
mergedIndex = i + 4;
}
}

if (f[0] == u'' && fields[i + 1][0] == 'E')
{
f[0] = u'';
@@ -2247,9 +2266,47 @@ vector<WordInfo> KiwiBuilder::extractAddWords(const U16MultipleReader& reader, s
return words;
}

void KiwiBuilder::convertHSData(
const vector<string>& inputPathes,
const string& outputPath,
const string& morphemeDefPath,
size_t morphemeDefMinCnt
) const
{
unique_ptr<KiwiBuilder> dummyBuilder;
const KiwiBuilder* srcBuilder = this;
MorphemeMap realMorph;
if (morphemeDefPath.empty())
{
realMorph = restoreMorphemeMap();
}
else
{
dummyBuilder = make_unique<KiwiBuilder>();
dummyBuilder->initMorphemes();
ifstream ifs;
realMorph = dummyBuilder->loadMorphemesFromTxt(openFile(ifs, morphemeDefPath), [&](POSTag tag, float cnt)
{
return cnt >= morphemeDefMinCnt;
});
srcBuilder = dummyBuilder.get();
}

RaggedVector<uint32_t> sents;
for (auto& path : inputPathes)
{
ifstream ifs;
srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph);
}

ofstream ofs;
sents.write_to_memory(openFile(ofs, outputPath, ios_base::binary));
}

HSDataset KiwiBuilder::makeHSDataset(const vector<string>& inputPathes,
size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers,
double dropoutProb,
double dropoutProbOnHistory,
const TokenFilter& tokenFilter,
const TokenFilter& windowFilter,
double splitRatio,
@@ -2259,7 +2316,7 @@ HSDataset KiwiBuilder::makeHSDataset(const vector<string>& inputPathes,
HSDataset* splitDataset
) const
{
HSDataset dataset{ batchSize, causalContextSize, windowSize, numWorkers, dropoutProb };
HSDataset dataset{ batchSize, causalContextSize, windowSize, numWorkers, dropoutProb, dropoutProbOnHistory };
auto& sents = dataset.sents.get();
const KiwiBuilder* srcBuilder = this;
MorphemeMap realMorph;
@@ -2301,8 +2358,25 @@ HSDataset KiwiBuilder::makeHSDataset(const vector<string>& inputPathes,

for (auto& path : inputPathes)
{
ifstream ifs;
srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph, splitRatio, splitDataset ? &splitDataset->sents.get() : nullptr);
try
{
ifstream ifs;
auto cvtSents = RaggedVector<uint32_t>::from_memory(openFile(ifs, path, ios_base::binary));
if (splitRatio > 0)
{
throw invalid_argument("splitDataset cannot be used with binary input");
}
for (auto s : cvtSents)
{
sents.emplace_back();
sents.insert_data(s.begin(), s.end());
}
}
catch (const runtime_error&)
{
ifstream ifs;
srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph, splitRatio, splitDataset ? &splitDataset->sents.get() : nullptr);
}
}
size_t tokenSize = sents.raw().empty() ? 0 : *max_element(sents.raw().begin(), sents.raw().end()) + 1;

84 changes: 67 additions & 17 deletions src/RaggedVector.hpp
Original file line number Diff line number Diff line change
@@ -94,8 +94,8 @@ namespace kiwi

auto operator[](size_t idx) -> Range<decltype(data.begin())>
{
size_t b = idx < ptrs.size() ? ptrs[idx] : data.size();
size_t e = idx + 1 < ptrs.size() ? ptrs[idx + 1] : data.size();
const size_t b = idx < ptrs.size() ? ptrs[idx] : data.size();
const size_t e = idx + 1 < ptrs.size() ? ptrs[idx + 1] : data.size();
return { data.begin() + b, data.begin() + e };
}

@@ -142,31 +142,81 @@ namespace kiwi
return { *this, size() };
}

utils::MemoryObject toMemory() const
utils::MemoryObject to_memory() const
{
utils::MemoryOwner ret{ sizeof(size_t) * 2 + sizeof(ValueTy) * data.size() + sizeof(size_t) * ptrs.size() };
utils::MemoryOwner ret{ 4 + sizeof(uint64_t) * 2 + sizeof(ValueTy) * data.size() + sizeof(uint64_t) * ptrs.size() };
utils::omstream ostr{ (char*)ret.get(), (ptrdiff_t)ret.size()};
size_t s;
write_to_memory(ostr);
return ret;
}

void write_to_memory(std::ostream& ostr) const
{
if (!ostr.write("KIRV", 4))
{
throw std::runtime_error("Failed to write RaggedVector memory object");
}
uint64_t s;
s = data.size();
ostr.write((const char*)&s, sizeof(size_t));
if (!ostr.write((const char*)&s, sizeof(uint64_t)))
{
throw std::runtime_error("Failed to write RaggedVector memory object");
}

s = ptrs.size();
ostr.write((const char*)&s, sizeof(size_t));
ostr.write((const char*)data.data(), sizeof(ValueTy) * data.size());
ostr.write((const char*)ptrs.data(), sizeof(size_t) * ptrs.size());
return ret;
if (!ostr.write((const char*)&s, sizeof(uint64_t)))
{
throw std::runtime_error("Failed to write RaggedVector memory object");
}

if (!ostr.write((const char*)data.data(), sizeof(ValueTy) * data.size()))
{
throw std::runtime_error("Failed to write RaggedVector memory object");
}

if (!ostr.write((const char*)ptrs.data(), sizeof(uint64_t) * ptrs.size()))
{
throw std::runtime_error("Failed to write RaggedVector memory object");
}
}

static RaggedVector fromMemory(std::istream& istr)
static RaggedVector from_memory(std::istream& istr)
{
RaggedVector ret;
size_t s;
istr.read((char*)&s, sizeof(size_t));
char buf[4];
if (!istr.read(buf, 4))
{
throw std::runtime_error("Invalid RaggedVector memory object");
}

if (memcmp(buf, "KIRV", 4) != 0)
{
throw std::runtime_error("Invalid RaggedVector memory object");
}

uint64_t s;
if (!istr.read((char*)&s, sizeof(uint64_t)))
{
throw std::runtime_error("Invalid RaggedVector memory object");
}
ret.data.resize(s);
istr.read((char*)&s, sizeof(size_t));

if (!istr.read((char*)&s, sizeof(uint64_t)))
{
throw std::runtime_error("Invalid RaggedVector memory object");
}
ret.ptrs.resize(s);
istr.read((char*)ret.data.data(), sizeof(ValueTy) * ret.data.size());
istr.read((char*)ret.ptrs.data(), sizeof(size_t) * ret.ptrs.size());

if (!istr.read((char*)ret.data.data(), sizeof(ValueTy) * ret.data.size()))
{
throw std::runtime_error("Invalid RaggedVector memory object");
}

if (!istr.read((char*)ret.ptrs.data(), sizeof(uint64_t) * ret.ptrs.size()))
{
throw std::runtime_error("Invalid RaggedVector memory object");
}
return ret;
}
};
}
}

0 comments on commit 7a26d4d

Please sign in to comment.