From 5ebbb989d84f4551e6d2d4375f042105dac503e4 Mon Sep 17 00:00:00 2001
From: Chase Geigle
Date: Mon, 30 Apr 2018 15:03:19 -0500
Subject: [PATCH] [topics]: Add bindings for LDA inferencer classes.

---
 deps/meta             |  2 +-
 src/metapy_topics.cpp | 64 +++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 63 insertions(+), 3 deletions(-)

diff --git a/deps/meta b/deps/meta
index d11e837..be553be 160000
--- a/deps/meta
+++ b/deps/meta
@@ -1 +1 @@
-Subproject commit d11e83748fed852c1ded61036010bae6b7984bd5
+Subproject commit be553be6a3f9136952c0e786fdbc6de9d5f8570e
diff --git a/src/metapy_topics.cpp b/src/metapy_topics.cpp
index 4556449..1154632 100644
--- a/src/metapy_topics.cpp
+++ b/src/metapy_topics.cpp
@@ -61,6 +61,14 @@ void metapy_bind_topics(py::module& m)
         })
         .def("num_topics", &topics::lda_model::num_topics);
 
+    py::class_<topics::inferencer>{m_topics, "LDAInferencer"}
+        .def("term_distribution",
+             [](const topics::inferencer& inf, topic_id k) {
+                 return py_multinomial{inf.term_distribution(k)};
+             },
+             py::arg("k"))
+        .def("num_topics", &topics::inferencer::num_topics);
+
     py::class_{m_topics, "LDACollapsedVB"}
         .def(py::init(),
              py::keep_alive<0, 1>(), py::arg("docs"), py::arg("num_topics"),
@@ -72,6 +80,31 @@ void metapy_bind_topics(py::module& m)
              },
              py::arg("num_iters"), py::arg("convergence") = 1e-3);
 
+    py::class_<topics::lda_cvb::inferencer, topics::inferencer>{m_topics,
+                                                                "CVBInferencer"}
+        .def("__init__",
+             [](topics::inferencer& inf, const std::string& cfgfile) {
+                 py::gil_scoped_release release;
+                 auto config = cpptoml::parse_file(cfgfile);
+                 new (&inf) topics::inferencer(*config);
+             },
+             py::arg("cfg_file"))
+        .def("__init__",
+             [](topics::inferencer& inf, const std::string& topicsfile,
+                double alpha) {
+                 py::gil_scoped_release release;
+                 std::ifstream topics_stream{topicsfile};
+                 new (&inf) topics::inferencer(topics_stream, alpha);
+             },
+             py::arg("topics_file"), py::arg("alpha"))
+        .def("infer",
+             [](const topics::lda_cvb::inferencer& inf,
+                const learn::feature_vector& doc, std::size_t max_iters,
+                double convergence) {
+                 return py_multinomial{inf(doc, max_iters, convergence)};
+             },
+             py::arg("doc"), py::arg("max_iters"), py::arg("convergence"));
+
     py::class_{m_topics, "LDAGibbs"}
         .def(py::init(),
              py::keep_alive<0, 1>(), py::arg("docs"), py::arg("num_topics"),
@@ -84,6 +117,31 @@ void metapy_bind_topics(py::module& m)
              },
              py::arg("num_iters"), py::arg("convergence") = 1e-6);
 
+    py::class_<topics::lda_gibbs::inferencer, topics::inferencer>{
+        m_topics, "GibbsInferencer"}
+        .def("__init__",
+             [](topics::inferencer& inf, const std::string& cfgfile) {
+                 auto config = cpptoml::parse_file(cfgfile);
+                 new (&inf) topics::inferencer(*config);
+             },
+             py::arg("cfg_file"))
+        .def("__init__",
+             [](topics::inferencer& inf, const std::string& topicsfile,
+                double alpha) {
+                 std::ifstream topics_stream{topicsfile};
+                 new (&inf) topics::inferencer(topics_stream, alpha);
+             },
+             py::arg("topics_file"), py::arg("alpha"))
+
+        .def("infer",
+             [](const topics::lda_gibbs::inferencer& inf,
+                const learn::feature_vector& doc, std::size_t num_iters,
+                std::size_t seed) {
+                 random::xoroshiro128 rng{seed};
+                 return py_multinomial{inf(doc, num_iters, rng)};
+             },
+             py::arg("doc"), py::arg("max_iters"), py::arg("rng_seed"));
+
     py::class_{
         m_topics, "LDAParallelGibbs"}
         .def(py::init(),
@@ -127,8 +185,10 @@ void metapy_bind_topics(py::module& m)
                  new (&model) topics::topic_model(theta, phi);
              })
-        .def("top_k", [](const topics::topic_model& model, topic_id tid,
-                         std::size_t k) { return model.top_k(tid, k); },
+        .def("top_k",
+             [](const topics::topic_model& model, topic_id tid, std::size_t k) {
+                 return model.top_k(tid, k);
+             },
              py::arg("tid"), py::arg("k") = 10)
         .def("top_k", [](const topics::topic_model& model, topic_id tid,
                          std::size_t k,