diff --git a/litstudy/__init__.py b/litstudy/__init__.py index fcb125e..d96ce64 100644 --- a/litstudy/__init__.py +++ b/litstudy/__init__.py @@ -61,6 +61,7 @@ build_corpus, train_nmf_model, train_lda_model, + train_elda_model, compute_word_distribution, calculate_embedding, ) # noqa: F401 diff --git a/litstudy/nlp.py b/litstudy/nlp.py index b0486fa..b939e6e 100644 --- a/litstudy/nlp.py +++ b/litstudy/nlp.py @@ -311,16 +311,69 @@ def train_lda_model(corpus: Corpus, num_topics, seed=0, **kwargs) -> TopicModel: :param num_topics: The number of topics to train. :param seed: The seed used for random number generation. - :param kwargs: Arguments passed to `gensim.models.lda.LdaModel`. + :param kwargs: Arguments passed to `gensim.models.lda.LdaModel` (gensim3) + or `gensim.models.ldamodel.LdaModel` (gensim4). """ - from gensim.models.lda import LdaModel dic = corpus.dictionary freqs = corpus.frequencies - model = LdaModel(list(corpus), **kwargs) + from importlib.metadata import version - doc2topic = corpus2dense(model[freqs], num_topics) + gensim_mayor = int(version("gensim").split(".")[0]) + + if gensim_mayor == 3: + from gensim.models.lda import LdaModel + + model = LdaModel(list(corpus), **kwargs) + elif gensim_mayor == 4: + from gensim.models.ldamodel import LdaModel + + model = LdaModel(freqs, id2word=dic, num_topics=num_topics, **kwargs) + else: + from sys import exit + + exit("LdaModel could not be imported from gensim 3 or 4.") + + doc2topic = corpus2dense(model[freqs], num_topics).T + topic2token = model.get_topics() + + return TopicModel(dic, doc2topic, topic2token) + + +def train_elda_model(corpus: Corpus, num_topics, num_models=4, seed=0, **kwargs) -> TopicModel: + """Train a topic model using ensemble LDA. + + :param num_topics: The number of topics to train. + :param num_models: The number of models to train. + :param seed: The seed used for random number generation. + :param kwargs: Arguments passed to `gensim.models.ensemblelda.EnsembleLda` (gensim4). + """ + + from importlib.metadata import version + + gensim_mayor = int(version("gensim").split(".")[0]) + + if gensim_mayor <= 3: + from sys import exit + + exit("EnsembleLda requires at least gensim 4.") + + dic = corpus.dictionary + freqs = corpus.frequencies + + from gensim.models.ensemblelda import EnsembleLda + + model = EnsembleLda( + topic_model_class="ldamulticore", + corpus=freqs, + id2word=dic, + num_topics=num_topics, + num_models=num_models, + **kwargs + ) + + doc2topic = corpus2dense(model[freqs], num_topics).T topic2token = model.get_topics() return TopicModel(dic, doc2topic, topic2token) diff --git a/litstudy/sources/scopus_csv.py b/litstudy/sources/scopus_csv.py index 1ee3fad..2679f15 100644 --- a/litstudy/sources/scopus_csv.py +++ b/litstudy/sources/scopus_csv.py @@ -1,6 +1,7 @@ """ support loading Scopus CSV export. """ + from typing import List, Optional from ..types import Document, Author, DocumentSet, DocumentIdentifier, Affiliation from ..common import robust_open