diff --git a/README.md b/README.md index e4592b5..de99423 100644 --- a/README.md +++ b/README.md @@ -20,42 +20,26 @@ - Lemmatization and Stemming - Visualization with [topicwizard](https://github.com/x-tabdeveloping/topicwizard) πŸ–ŒοΈ +## New in version 0.12.0: Seeded topic modeling -## New in version 0.11.0: Vectorizers Module - -You can now use a set of custom vectorizers for topic modeling over **phrases**, as well as **lemmata** and **stems**. +You can now specify an aspect in KeyNMF from which you want to investigate your corpus by specifying a seed phrase. ```python from turftopic import KeyNMF -from turftopic.vectorizers.spacy import NounPhraseCountVectorizer -model = KeyNMF( - n_components=10, - vectorizer=NounPhraseCountVectorizer("en_core_web_sm"), -) +model = KeyNMF(5, seed_phrase="Is the death penalty moral?") model.fit(corpus) + model.print_topics() ``` | Topic ID | Highest Ranking | | - | - | -| | ... | -| 3 | fanaticism, theism, fanatism, all fanatism, theists, strong theism, strong atheism, fanatics, precisely some theists, all theism | -| 4 | religion foundation darwin fish bumper stickers, darwin fish, atheism, 3d plastic fish, fish symbol, atheist books, atheist organizations, negative atheism, positive atheism, atheism index | -| | ... | - -Turftopic now also comes with a **Chinese vectorizer** for easier use, as well as a generalist **multilingual vectorizer**. - -```python -from turftopic.vectorizers.chinese import default_chinese_vectorizer -from turftopic.vectorizers.spacy import TokenCountVectorizer - -chinese_vectorizer = default_chinese_vectorizer() -arabic_vectorizer = TokenCountVectorizer("ar", remove_stopwords=True) -danish_vectorizer = TokenCountVectorizer("da", remove_stopwords=True) -... - -``` +| 0 | morality, moral, immoral, morals, objective, morally, animals, society, species, behavior | +| 1 | armenian, armenians, genocide, armenia, turkish, turks, soviet, massacre, azerbaijan, kurdish | +| 2 | murder, punishment, death, innocent, penalty, kill, crime, moral, criminals, executed | +| 3 | gun, guns, firearms, crime, handgun, firearm, weapons, handguns, law, criminals | +| 4 | jews, israeli, israel, god, jewish, christians, sin, christian, palestinians, christianity | ## Basics [(Documentation)](https://x-tabdeveloping.github.io/turftopic/) @@ -179,6 +163,29 @@ model.print_topics() | 3 | Storage Technologies | disk, drive, scsi, drives, disks, floppy, ide, dos, controller, boot | | | ... | +### Vectorizers Module + +You can use a set of custom vectorizers for topic modeling over **phrases**, as well as **lemmata** and **stems**. + +```python +from turftopic import KeyNMF +from turftopic.vectorizers.spacy import NounPhraseCountVectorizer + +model = KeyNMF( + n_components=10, + vectorizer=NounPhraseCountVectorizer("en_core_web_sm"), +) +model.fit(corpus) +model.print_topics() +``` + +| Topic ID | Highest Ranking | +| - | - | +| | ... | +| 3 | fanaticism, theism, fanatism, all fanatism, theists, strong theism, strong atheism, fanatics, precisely some theists, all theism | +| 4 | religion foundation darwin fish bumper stickers, darwin fish, atheism, 3d plastic fish, fish symbol, atheist books, atheist organizations, negative atheism, positive atheism, atheism index | +| | ... | + ### Visualization Turftopic does not come with built-in visualization utilities, [topicwizard](https://github.com/x-tabdeveloping/topicwizard), an interactive topic model visualization library, is compatible with all models from Turftopic. diff --git a/docs/KeyNMF.md b/docs/KeyNMF.md index f5367f3..c1455d4 100644 --- a/docs/KeyNMF.md +++ b/docs/KeyNMF.md @@ -8,20 +8,30 @@ while taking inspiration from classical matrix-decomposition approaches for extr
Schematic overview of KeyNMF
+ Here's an example of how you can fit and interpret a KeyNMF model in the easiest way. ```python from turftopic import KeyNMF -model = KeyNMF(10, top_n=6) +model = KeyNMF(10, encoder="paraphrase-MiniLM-L3-v2") model.fit(corpus) model.print_topics() ``` +!!! question "Which Embedding model should I use" + - You should probably use KeyNMF with a `paraphrase-` type embedding model. These seem to perform best in most tasks. Some examples include: + - [paraphrase-MiniLM-L3-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L3-v2) - Absolutely tiny :mouse: + - [paraphrase-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-mpnet-base-v2) - High performance :star2: + - [paraphrase-multilingual-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2) - Multilingual, high-performance :earth_americas: :star2: + - KeyNMF works remarkably well with static models, which are incredibly fast, even on your laptop: + - [sentence-transformers/static-retrieval-mrl-en-v1](https://huggingface.co/sentence-transformers/static-retrieval-mrl-en-v1) - Blazing Fast :zap: + - [sentence-transformers/static-similarity-mrl-multilingual-v1](https://huggingface.co/sentence-transformers/static-similarity-mrl-multilingual-v1) - Multilingual, Blazing Fast :earth_americas: :zap: + ## How does KeyNMF work? -### Keyword Extraction +#### Keyword Extraction KeyNMF discovers topics based on the importances of keywords for a given document. This is done by embedding words in a document, and then extracting the cosine similarities of documents to words using a transformer-model. @@ -78,7 +88,7 @@ keyword_matrix = model.extract_keywords(corpus) model.fit(None, keywords=keyword_matrix) ``` -### Topic Discovery +#### Topic Discovery Topics in this matrix are then discovered using Non-negative Matrix Factorization. Essentially the model tries to discover underlying dimensions/factors along which most of the variance in term importance @@ -94,70 +104,89 @@ can be explained. You can fit KeyNMF on the raw corpus, with precomputed embeddings or with precomputed keywords. -```python -# Fitting just on the corpus -model.fit(corpus) -# Fitting with precomputed embeddings -from sentence_transformers import SentenceTransformer +=== "Fitting on a corpus" + ```python + model.fit(corpus) + ``` -trf = SentenceTransformer("all-MiniLM-L6-v2") -embeddings = trf.encode(corpus) +=== "Pre-computed embeddings" + ```python + from sentence_transformers import SentenceTransformer -model = KeyNMF(10, encoder=trf) -model.fit(corpus, embeddings=embeddings) + trf = SentenceTransformer("all-MiniLM-L6-v2") + embeddings = trf.encode(corpus) -# Fitting with precomputed keyword matrix -keyword_matrix = model.extract_keywords(corpus) -model.fit(None, keywords=keyword_matrix) -``` + model = KeyNMF(10, encoder=trf) + model.fit(corpus, embeddings=embeddings) + ``` +=== "Pre-computed keyword matrix" + ```python + keyword_matrix = model.extract_keywords(corpus) + model.fit(None, keywords=keyword_matrix) + ``` -### Asymmetric and Instruction-tuned Embedding Models +## Seeded Topic Modeling -Some embedding models can be used together with prompting, or encode queries and passages differently. -This is important for KeyNMF, as it is explicitly based on keyword retrieval, and its performance can be substantially enhanced by using asymmetric or prompted embeddings. -Microsoft's E5 models are, for instance, all prompted by default, and it would be detrimental to performance not to do so yourself. +When investigating a set of documents, you might already have an idea about what aspects you would like to explore. +In KeyNMF, you can describe this aspect, from which you want to investigate your corpus, using a free-text seed-phrase, +which will then be used to only extract topics, which are relevant to your research question. -In these cases, you're better off NOT passing a string to Turftopic models, but explicitly loading the model using `sentence-transformers`. +??? info "How is this done?" -Here's an example of using instruct models for keyword retrieval with KeyNMF. -In this case, documents will serve as the queries and words as the passages: + KeyNMF encodes the seed phrase into a seed-embedding. + Word importance scores in a document get weighted by their similarity to the seed-embedding. + + - Embed seed-phrase into a seed-embedding: $s$ + - When extracting keywords from a document: + 1. Let $x_d$ be the document's embedding produced with the encoder model. + 2. Let the document's relevance be $r_d = \text{sim}(d,w)$ + 3. For each word $w$: + 1. Let the word's importance in the keyword matrix be: $\text{sim}(d, w) \cdot r_d$ if $r_d > 0$, otherwise $0$ ```python from turftopic import KeyNMF -from sentence_transformers import SentenceTransformer -encoder = SentenceTransformer( - "intfloat/multilingual-e5-large-instruct", - prompts={ - "query": "Instruct: Retrieve relevant keywords from the given document. Query: " - "passage": "Passage: " - }, - # Make sure to set default prompt to query! - default_prompt_name="query", -) -model = KeyNMF(10, encoder=encoder) +model = KeyNMF(5, seed_phrase="") +model.fit(corpus) + +model.print_topics() ``` -And a regular, asymmetric example: -```python -encoder = SentenceTransformer( - "intfloat/e5-large-v2", - prompts={ - "query": "query: " - "passage": "passage: " - }, - # Make sure to set default prompt to query! - default_prompt_name="query", -) -model = KeyNMF(10, encoder=encoder) -``` +=== "`'Is the death penalty moral?'`" + + | Topic ID | Highest Ranking | + | - | - | + | 0 | morality, moral, immoral, morals, objective, morally, animals, society, species, behavior | + | 1 | armenian, armenians, genocide, armenia, turkish, turks, soviet, massacre, azerbaijan, kurdish | + | 2 | murder, punishment, death, innocent, penalty, kill, crime, moral, criminals, executed | + | 3 | gun, guns, firearms, crime, handgun, firearm, weapons, handguns, law, criminals | + | 4 | jews, israeli, israel, god, jewish, christians, sin, christian, palestinians, christianity | + +=== "`'Evidence for the existence of god'`" + + | Topic ID | Highest Ranking | + | - | - | + | 0 | atheist, atheists, religion, religious, theists, beliefs, christianity, christian, religions, agnostic | + | 1 | bible, christians, christian, christianity, church, scripture, religion, jesus, faith, biblical | + | 2 | god, existence, exist, exists, universe, creation, argument, creator, believe, life | + | 3 | believe, faith, belief, evidence, blindly, believing, gods, believed, beliefs, convince | + | 4 | atheism, atheists, agnosticism, belief, arguments, believe, existence, alt, believing, argument | + +=== "`'Operating system kernels'`" + + | Topic ID | Highest Ranking | + | - | - | + | 0 | windows, dos, os, microsoft, ms, apps, pc, nt, file, shareware | + | 1 | ram, motherboard, card, monitor, memory, cpu, vga, mhz, bios, intel | + | 2 | unix, os, linux, intel, systems, programming, applications, compiler, software, platform | + | 3 | disk, scsi, disks, drive, floppy, drives, dos, controller, cd, boot | + | 4 | software, mac, hardware, ibm, graphics, apple, computer, pc, modem, program | -Setting the default prompt to `query` is especially important, when you are precomputing embeddings, as `query` should always be your default prompt to embed documents with. -### Dynamic Topic Modeling +## Dynamic Topic Modeling KeyNMF is also capable of modeling topics over time. This happens by fitting a KeyNMF model first on the entire corpus, then @@ -229,7 +258,48 @@ model.plot_topics_over_time()
Topics over time in a Dynamic KeyNMF model.
-### Online Topic Modeling +## Hierarchical Topic Modeling + +When you suspect that subtopics might be present in the topics you find with the model, KeyNMF can be used to discover topics further down the hierarchy. + +This is done by utilising a special case of **weighted NMF**, where documents are weighted by how high they score on the parent topic. + +??? info "Click to see formula" + 1. Decompose keyword matrix $M \approx WH$ + 2. To find subtopics in topic $j$, define document weights $w$ as the $j$th column of $W$. + 3. Estimate subcomponents with **wNMF** $M \approx \mathring{W} \mathring{H}$ with document weight $w$ + 1. Initialise $\mathring{H}$ and $\mathring{W}$ randomly. + 2. Perform multiplicative updates until convergence.
+ $\mathring{W}^T = \mathring{W}^T \odot \frac{\mathring{H} \cdot (M^T \odot w)}{\mathring{H} \cdot \mathring{H}^T \cdot (\mathring{W}^T \odot w)}$
+ $\mathring{H}^T = \mathring{H}^T \odot \frac{ (M^T \odot w)\cdot \mathring{W}}{\mathring{H}^T \cdot (\mathring{W}^T \odot w) \cdot \mathring{W}}$ + 4. To sufficiently differentiate the subcomponents from each other a pseudo-c-tf-idf weighting scheme is applied to $\mathring{H}$: + 1. $\mathring{H} = \mathring{H}_{ij} \odot ln(1 + \frac{A}{1+\sum_k \mathring{H}_{kj}})$, where $A$ is the average of all elements in $\mathring{H}$ + +To create a hierarchical model, you can use the `hierarchy` property of the model. + +```python +# This divides each of the topics in the model to 3 subtopics. +model.hierarchy.divide_children(n_subtopics=3) +print(model.hierarchy) +``` + +
+ +Root
+β”œβ”€β”€ 0: windows, dos, os, disk, card, drivers, file, pc, files, microsoft
+β”‚ β”œβ”€β”€ 0.0: dos, file, disk, files, program, windows, disks, shareware, norton, memory
+β”‚ β”œβ”€β”€ 0.1: os, unix, windows, microsoft, apps, nt, ibm, ms, os2, platform
+β”‚ └── 0.2: card, drivers, monitor, driver, vga, ram, motherboard, cards, graphics, ati
+└── 1: atheism, atheist, atheists, religion, christians, religious, belief, christian, god, beliefs
+. β”œβ”€β”€ 1.0: atheism, alt, newsgroup, reading, faq, islam, questions, read, newsgroups, readers
+. β”œβ”€β”€ 1.1: atheists, atheist, belief, theists, beliefs, religious, religion, agnostic, gods, religions
+. └── 1.2: morality, bible, christian, christians, moral, christianity, biblical, immoral, god, religion
+
+
+ +For a detailed tutorial on hierarchical modeling click [here](hierarchical.md). + +## Online Topic Modeling KeyNMF can also be fitted in an online manner. This is done by fitting NMF with batches of data instead of the whole dataset at once. @@ -326,7 +396,7 @@ for epoch in range(5): model.partial_fit(keywords=keyword_batch) ``` -#### Dynamic Online Topic Modeling +### Dynamic Online Topic Modeling KeyNMF can be online fitted in a dynamic manner as well. This is useful when you have large corpora of text over time, or when you want to fit the model on future information flowing in @@ -354,46 +424,49 @@ for batch in batched(zip(corpus, timestamps)): model.partial_fit_dynamic(text_batch, timestamps=ts_batch, bins=bins) ``` -### Hierarchical Topic Modeling +## Asymmetric and Instruction-tuned Embedding Models -When you suspect that subtopics might be present in the topics you find with the model, KeyNMF can be used to discover topics further down the hierarchy. - -This is done by utilising a special case of **weighted NMF**, where documents are weighted by how high they score on the parent topic. +Some embedding models can be used together with prompting, or encode queries and passages differently. +This is important for KeyNMF, as it is explicitly based on keyword retrieval, and its performance can be substantially enhanced by using asymmetric or prompted embeddings. +Microsoft's E5 models are, for instance, all prompted by default, and it would be detrimental to performance not to do so yourself. -??? info "Click to see formula" - 1. Decompose keyword matrix $M \approx WH$ - 2. To find subtopics in topic $j$, define document weights $w$ as the $j$th column of $W$. - 3. Estimate subcomponents with **wNMF** $M \approx \mathring{W} \mathring{H}$ with document weight $w$ - 1. Initialise $\mathring{H}$ and $\mathring{W}$ randomly. - 2. Perform multiplicative updates until convergence.
- $\mathring{W}^T = \mathring{W}^T \odot \frac{\mathring{H} \cdot (M^T \odot w)}{\mathring{H} \cdot \mathring{H}^T \cdot (\mathring{W}^T \odot w)}$
- $\mathring{H}^T = \mathring{H}^T \odot \frac{ (M^T \odot w)\cdot \mathring{W}}{\mathring{H}^T \cdot (\mathring{W}^T \odot w) \cdot \mathring{W}}$ - 4. To sufficiently differentiate the subcomponents from each other a pseudo-c-tf-idf weighting scheme is applied to $\mathring{H}$: - 1. $\mathring{H} = \mathring{H}_{ij} \odot ln(1 + \frac{A}{1+\sum_k \mathring{H}_{kj}})$, where $A$ is the average of all elements in $\mathring{H}$ +In these cases, you're better off NOT passing a string to Turftopic models, but explicitly loading the model using `sentence-transformers`. -To create a hierarchical model, you can use the `hierarchy` property of the model. +Here's an example of using instruct models for keyword retrieval with KeyNMF. +In this case, documents will serve as the queries and words as the passages: ```python -# This divides each of the topics in the model to 3 subtopics. -model.hierarchy.divide_children(n_subtopics=3) -print(model.hierarchy) +from turftopic import KeyNMF +from sentence_transformers import SentenceTransformer + +encoder = SentenceTransformer( + "intfloat/multilingual-e5-large-instruct", + prompts={ + "query": "Instruct: Retrieve relevant keywords from the given document. Query: " + "passage": "Passage: " + }, + # Make sure to set default prompt to query! + default_prompt_name="query", +) +model = KeyNMF(10, encoder=encoder) ``` -
- -Root
-β”œβ”€β”€ 0: windows, dos, os, disk, card, drivers, file, pc, files, microsoft
-β”‚ β”œβ”€β”€ 0.0: dos, file, disk, files, program, windows, disks, shareware, norton, memory
-β”‚ β”œβ”€β”€ 0.1: os, unix, windows, microsoft, apps, nt, ibm, ms, os2, platform
-β”‚ └── 0.2: card, drivers, monitor, driver, vga, ram, motherboard, cards, graphics, ati
-└── 1: atheism, atheist, atheists, religion, christians, religious, belief, christian, god, beliefs
-. β”œβ”€β”€ 1.0: atheism, alt, newsgroup, reading, faq, islam, questions, read, newsgroups, readers
-. β”œβ”€β”€ 1.1: atheists, atheist, belief, theists, beliefs, religious, religion, agnostic, gods, religions
-. └── 1.2: morality, bible, christian, christians, moral, christianity, biblical, immoral, god, religion
-
-
+And a regular, asymmetric example: -For a detailed tutorial on hierarchical modeling click [here](hierarchical.md). +```python +encoder = SentenceTransformer( + "intfloat/e5-large-v2", + prompts={ + "query": "query: " + "passage": "passage: " + }, + # Make sure to set default prompt to query! + default_prompt_name="query", +) +model = KeyNMF(10, encoder=encoder) +``` + +Setting the default prompt to `query` is especially important, when you are precomputing embeddings, as `query` should always be your default prompt to embed documents with. ## API Reference diff --git a/docs/images/nmf_explanation.svg b/docs/images/nmf_explanation.svg new file mode 100644 index 0000000..45321a1 --- /dev/null +++ b/docs/images/nmf_explanation.svg @@ -0,0 +1,772 @@ + + + + + + + + + + + + + + + + + doc0 + doc1 + doc2 + ... + "dog" + "cat" + etc. + 0 + 3 + 2 + 4 + 1 + 5 + + + + + + + + + + + + + + + + + 0 + 3 + 2 + 4 + 1 + 5 + = + Doc-Term Matrix (X) + + + + + + + + + doc0 + doc1 + doc2 + ... + topic0 + topic1 + etc. + 0.05 + 0.69 + 0.72 + 0.96 + 0.02 + 0.11 + Doc-Topic Matrix (W) + Topic-Term Matrix (H) + + + + + + + + + topic0 + topic1 + topic2 + ... + "dog" + "cat" + etc. + 1.7 + 2.5 + 3.4 + 0.5 + 0.01 + 0 + + + diff --git a/docs/seeded.md b/docs/seeded.md new file mode 100644 index 0000000..c175448 --- /dev/null +++ b/docs/seeded.md @@ -0,0 +1,59 @@ +# Seeded Topic Modeling + +When investigating a set of documents, you might already have an idea about what aspects you would like to explore. +Some models are able to account for this by taking seed phrases or words. +This is currently only possible with KeyNMF in Turftopic, but will likely be extended in the future. + +In [KeyNMF](../keynmf.md), you can describe the aspect, from which you want to investigate your corpus, using a free-text seed-phrase, +which will then be used to only extract topics, which are relevant to your research question. + +In this example we investigate the 20Newsgroups corpus from three different aspects: + +```python +from sklearn.datasets import fetch_20newsgroups + +from turftopic import KeyNMF + +corpus = fetch_20newsgroups( + subset="all", + remove=("headers", "footers", "quotes"), +).data + +model = KeyNMF(5, seed_phrase="") +model.fit(corpus) + +model.print_topics() +``` + + +=== "`'Is the death penalty moral?'`" + + | Topic ID | Highest Ranking | + | - | - | + | 0 | morality, moral, immoral, morals, objective, morally, animals, society, species, behavior | + | 1 | armenian, armenians, genocide, armenia, turkish, turks, soviet, massacre, azerbaijan, kurdish | + | 2 | murder, punishment, death, innocent, penalty, kill, crime, moral, criminals, executed | + | 3 | gun, guns, firearms, crime, handgun, firearm, weapons, handguns, law, criminals | + | 4 | jews, israeli, israel, god, jewish, christians, sin, christian, palestinians, christianity | + +=== "`'Evidence for the existence of god'`" + + | Topic ID | Highest Ranking | + | - | - | + | 0 | atheist, atheists, religion, religious, theists, beliefs, christianity, christian, religions, agnostic | + | 1 | bible, christians, christian, christianity, church, scripture, religion, jesus, faith, biblical | + | 2 | god, existence, exist, exists, universe, creation, argument, creator, believe, life | + | 3 | believe, faith, belief, evidence, blindly, believing, gods, believed, beliefs, convince | + | 4 | atheism, atheists, agnosticism, belief, arguments, believe, existence, alt, believing, argument | + +=== "`'Operating system kernels'`" + + | Topic ID | Highest Ranking | + | - | - | + | 0 | windows, dos, os, microsoft, ms, apps, pc, nt, file, shareware | + | 1 | ram, motherboard, card, monitor, memory, cpu, vga, mhz, bios, intel | + | 2 | unix, os, linux, intel, systems, programming, applications, compiler, software, platform | + | 3 | disk, scsi, disks, drive, floppy, drives, dos, controller, cd, boot | + | 4 | software, mac, hardware, ibm, graphics, apple, computer, pc, modem, program | + + diff --git a/mkdocs.yml b/mkdocs.yml index c5ac574..373d62a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -8,6 +8,7 @@ nav: - Interpreting and Visualizing Models: model_interpretation.md - Modifying and Finetuning Models: finetuning.md - Saving and Loading Models: persistence.md + - Seeded Topic Modeling: seeded.md - Dynamic Topic Modeling: dynamic.md - Online Topic Modeling: online.md - Hierarchical Topic Modeling: hierarchical.md diff --git a/pyproject.toml b/pyproject.toml index ac15535..1deed16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ line-length=79 [tool.poetry] name = "turftopic" -version = "0.11.0" +version = "0.12.0" description = "Topic modeling with contextual representations from sentence transformers." authors = ["MΓ‘rton Kardos "] license = "MIT" diff --git a/turftopic/models/_keynmf.py b/turftopic/models/_keynmf.py index 92518fa..9a9d600 100644 --- a/turftopic/models/_keynmf.py +++ b/turftopic/models/_keynmf.py @@ -120,6 +120,8 @@ def batch_extract_keywords( self, documents: list[str], embeddings: Optional[np.ndarray] = None, + seed_embedding: Optional[np.ndarray] = None, + fitting: bool = True, ) -> list[dict[str, float]]: if not len(documents): return [] @@ -135,13 +137,25 @@ def batch_extract_keywords( "Number of documents doesn't match number of embeddings." ) keywords = [] - vectorizer = clone(self.vectorizer) - document_term_matrix = vectorizer.fit_transform(documents) - batch_vocab = vectorizer.get_feature_names_out() + if fitting: + document_term_matrix = self.vectorizer.fit_transform(documents) + else: + document_term_matrix = self.vectorizer.transform(documents) + batch_vocab = self.vectorizer.get_feature_names_out() new_terms = list(set(batch_vocab) - set(self.key_to_index.keys())) if len(new_terms): self._add_terms(new_terms) total = embeddings.shape[0] + # Relevance based on similarity to seed embedding + document_relevance = None + if seed_embedding is not None: + if self.metric == "cosine": + document_relevance = cosine_similarity( + [seed_embedding], embeddings + )[0] + else: + document_relevance = np.dot(embeddings, seed_embedding) + document_relevance[document_relevance < 0] = 0 for i in range(total): terms = document_term_matrix[i, :].todense() embedding = embeddings[i].reshape(1, -1) @@ -162,14 +176,13 @@ def batch_extract_keywords( ) ) if self.metric == "cosine": - sim = cosine_similarity(embedding, word_embeddings).astype( - np.float64 - ) + sim = cosine_similarity(embedding, word_embeddings) sim = np.ravel(sim) else: - sim = np.dot(word_embeddings, embedding[0]).T.astype( - np.float64 - ) + sim = np.dot(word_embeddings, embedding[0]).T + # If a seed is specified, we multiply by the document's relevance + if document_relevance is not None: + sim = document_relevance[i] * sim kth = min(self.top_n, len(sim) - 1) top = np.argpartition(-sim, kth)[:kth] top_words = batch_vocab[important_terms][top] diff --git a/turftopic/models/keynmf.py b/turftopic/models/keynmf.py index 4b7b5d2..a8e11d7 100644 --- a/turftopic/models/keynmf.py +++ b/turftopic/models/keynmf.py @@ -49,6 +49,10 @@ class KeyNMF(ContextualModel, DynamicTopicModel): Random state to use so that results are exactly reproducible. metric: "cosine" or "dot", default "cosine" Similarity metric to use for keyword extraction. + seed_phrase: str, default None + Describes an aspect of the corpus that the model should explore. + It can be a free-text query, such as + "Christian Denominations: Protestantism and Catholicism" """ def __init__( @@ -61,6 +65,7 @@ def __init__( top_n: int = 25, random_state: Optional[int] = None, metric: Literal["cosine", "dot"] = "cosine", + seed_phrase: Optional[str] = None, ): self.random_state = random_state self.n_components = n_components @@ -85,11 +90,16 @@ def __init__( encoder=self.encoder_, metric=self.metric, ) + self.seed_phrase = seed_phrase + self.seed_embedding = None + if self.seed_phrase is not None: + self.seed_embedding = self.encoder_.encode([self.seed_phrase])[0] def extract_keywords( self, batch_or_document: Union[str, list[str]], embeddings: Optional[np.ndarray] = None, + fitting: bool = True, ) -> list[dict[str, float]]: """Extracts keywords from a document or a batch of documents. @@ -103,7 +113,10 @@ def extract_keywords( if isinstance(batch_or_document, str): batch_or_document = [batch_or_document] return self.extractor.batch_extract_keywords( - batch_or_document, embeddings=embeddings + batch_or_document, + embeddings=embeddings, + seed_embedding=self.seed_embedding, + fitting=fitting, ) def vectorize( @@ -249,7 +262,9 @@ def transform( ) if keywords is None: keywords = self.extract_keywords( - list(raw_documents), embeddings=embeddings + list(raw_documents), + embeddings=embeddings, + fitting=False, ) return self.model.transform(keywords)