Skip to content

Commit

Permalink
feat:disambiguation and summarization (#88)
Browse files Browse the repository at this point in the history
* feat:disambiguation and summarization

Improve answers by using BM25 to select the best sentences from Wikipedia.

* bump python
  • Loading branch information
JarbasAl authored Jan 5, 2025
1 parent 27ac28a commit 29eb94e
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 52 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ on:
jobs:
build_tests:
strategy:
max-parallel: 2
max-parallel: 3
matrix:
python-version: [ 3.7, 3.8, 3.9, "3.10" ]
python-version: [3.9, "3.10", "3.11"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release_workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v1
with:
python-version: 3.8
python-version: "3.11"
- name: Install Build Tools
run: |
python -m pip install build wheel
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ on:
jobs:
unit_tests:
strategy:
max-parallel: 2
max-parallel: 3
matrix:
python-version: [ 3.7, 3.8, 3.9, "3.10" ]
python-version: [3.9, "3.10", "3.11" ]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
Expand Down
188 changes: 145 additions & 43 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,45 +12,83 @@
import concurrent.futures
import os.path
import re
from typing import Optional, Tuple
from functools import lru_cache
from typing import Optional, Tuple, List, Dict

import requests
from langcodes import closest_supported_match
from padacioso import IntentContainer
from quebra_frases import sentence_tokenize

from ovos_bm25_solver import BM25MultipleChoiceSolver
from ovos_bus_client.session import SessionManager, Session
from ovos_plugin_manager.templates.solvers import QuestionSolver
from ovos_utils import classproperty, flatten_list
from ovos_utils.bracket_expansion import expand_template
from ovos_utils.gui import can_use_gui
from ovos_utils.log import LOG
from ovos_utils.parse import fuzzy_match, MatchStrategy
from ovos_utils.process_utils import RuntimeRequirements
from ovos_workshop.decorators import intent_handler, common_query
from ovos_workshop.intents import IntentBuilder
from ovos_workshop.skills.ovos import OVOSSkill
from padacioso import IntentContainer
from quebra_frases import sentence_tokenize


@lru_cache(maxsize=128)
def rm_parentheses(text: str) -> str:
"""helper to remove the text between paranthesis in a wikipedia summary,
makes the text more natural and speakable"""
"""
Remove text enclosed in parentheses from the given string.
Args:
text (str): Input string.
Returns:
str: String with parentheses and their contents removed.
"""
return re.sub(r"\((.*?)\)", "", text).replace(" ", " ")


class WikipediaSolver(QuestionSolver):
"""
A solver for answering questions using Wikipedia search and summaries.
Attributes:
priority (int): Priority of the solver.
enable_tx (bool): Transmission enable status.
kw_matchers (Dict[str, IntentContainer]): Registered keyword extractors by language.
"""
priority = 40
enable_tx = False
kw_matchers = {}
kw_matchers: Dict[str, IntentContainer] = {}

# Utils to extract keywords from text
@classmethod
def register_kw_extractors(cls, samples: list, lang: str):
def register_kw_extractors(cls, samples: List[str], lang: str) -> None:
"""
Register keyword extractors for a given language.
Args:
samples (List[str]): List of sample utterances.
lang (str): Language code.
"""
lang = lang.split("-")[0]
if lang not in cls.kw_matchers:
cls.kw_matchers[lang] = IntentContainer()
cls.kw_matchers[lang].add_intent("question", samples)

@classmethod
def extract_keyword(cls, utterance: str, lang: str):
@lru_cache(maxsize=128)
def extract_keyword(cls, utterance: str, lang: str) -> Optional[str]:
"""
Extract a keyword from an utterance for a given language.
Args:
utterance (str): Input text.
lang (str): Language code.
Returns:
Optional[str]: Extracted keyword or None.
"""
lang = lang.split("-")[0]
if lang not in cls.kw_matchers:
return None
Expand All @@ -63,31 +101,83 @@ def extract_keyword(cls, utterance: str, lang: str):
LOG.debug(f"Could not extract search keyword for '{lang}' from '{utterance}'")
return kw

def get_page_data(self, pid: str, lang: str):
"""Fetch detailed data for a single Wikipedia page."""
@staticmethod
@lru_cache(maxsize=128)
def get_page_data(pid: str, lang: str) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""
Fetch detailed data for a specific Wikipedia page.
Args:
pid (str): Page ID.
lang (str): Language code.
Returns:
Tuple[Optional[str], Optional[str], Optional[str]]: Page title, summary, and image URL.
"""
url = (
f"https://{lang}.wikipedia.org/w/api.php?format=json&action=query&"
f"prop=extracts|pageimages&exintro&explaintext&redirects=1&pageids={pid}"
)
try:
disambiguation_indicators = ["may refer to:", "refers to:"]
response = requests.get(url, timeout=5).json()
page = response["query"]["pages"][pid]
summary = rm_parentheses(page.get("extract", ""))
if "commonly refers to:" in summary:
return None, None, None # disambiguation list page
if any(i in summary for i in disambiguation_indicators):
return None, None, None # Disambiguation list page
img = None
if "thumbnail" in page:
thumbnail = page["thumbnail"]["source"]
parts = thumbnail.split("/")[:-1]
img = "/".join(part for part in parts if part != "thumb")
ans = flatten_list([sentence_tokenize(s) for s in summary.split("\n")])

return page["title"], ans, img
return page["title"], summary, img
except Exception as e:
LOG.error(f"Error fetching page data for PID {pid}: {e}")
return None, None, None

def get_data(self, query: str, lang: Optional[str] = None, units: Optional[str] = None):
@staticmethod
@lru_cache(maxsize=128)
def summarize(query: str, summary: str) -> str:
    """
    Condense a summary to the sentences most relevant to a query.

    The summary is split into sentences, the sentences are reranked
    against the query by a BM25MultipleChoiceSolver, and the first three
    reranked sentences are joined with single spaces, in the order the
    reranker returned them.

    Args:
        query (str): User query used as the ranking context.
        summary (str): Wikipedia summary text.
    Returns:
        str: The top-ranked sentences joined into a single string.
    """
    max_sentences = 3
    candidates = sentence_tokenize(summary)
    solver = BM25MultipleChoiceSolver()
    best = solver.rerank(query, candidates)[:max_sentences]
    # Only index 1 of each reranked entry is used; presumably entries are
    # (score, sentence) pairs -- confirm against the solver's rerank API.
    picked = [entry[1] for entry in best]
    return " ".join(picked)

@staticmethod
@lru_cache(maxsize=128)
def score_page(query: str, title: str, summary: str, idx: int) -> float:
    """
    Compute a relevance score for one Wikipedia page.

    The score is the product of three factors:
      * the better of two fuzzy matches between the query and the title
        (the title as-is, and the title with parenthesised text removed);
      * a token-set fuzzy match between the summary and the title;
      * a position weight that drops by 0.05 per step in the original
        search-result order.

    Args:
        query (str): User query.
        title (str): Page title.
        summary (str): Page summary text to compare against the title.
        idx (int): Index of the page in the original search result order.
    Returns:
        float: Relevance score; higher means more relevant.
    """
    strategy = MatchStrategy.DAMERAU_LEVENSHTEIN_SIMILARITY
    raw_title_match = fuzzy_match(query, title, strategy)
    clean_title_match = fuzzy_match(query, rm_parentheses(title), strategy)
    title_score = max(raw_title_match, clean_title_match)

    summary_score = fuzzy_match(summary, title, MatchStrategy.TOKEN_SET_RATIO)

    # Favor the original order returned by Wikipedia.
    position_weight = 1 - (idx * 0.05)
    return title_score * summary_score * position_weight

def get_data(self, query: str, lang: Optional[str] = None, units: Optional[str] = None,
skip_disambiguation: bool = False):
"""Fetch Wikipedia search results and detailed data concurrently."""
LOG.debug(f"WikiSolver query: {query}")
lang = (lang or self.default_lang).split("-")[0]
Expand All @@ -109,11 +199,13 @@ def get_data(self, query: str, lang: Optional[str] = None, units: Optional[str]
return self.get_data(fallback_query, lang=lang, units=units)
return {}

LOG.debug(f"Matched {len(search_results)} Wikipedia pages")
top_k = 3 if not skip_disambiguation else 1
LOG.debug(f"Matched {len(search_results)} Wikipedia pages, using top {top_k}")
search_results = search_results[:top_k]

# Prepare for parallel fetch and maintain original order
summaries = [None] * len(search_results) # List to hold results in original order
with concurrent.futures.ThreadPoolExecutor() as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
future_to_idx = {
executor.submit(self.get_page_data, str(r["pageid"]), lang): idx
for idx, r in enumerate(search_results)
Expand All @@ -126,41 +218,48 @@ def get_data(self, query: str, lang: Optional[str] = None, units: Optional[str]
if title and ans:
summaries[idx] = (title, ans, img)

# Filter out None entries and sort based on original order
summaries = [entry for entry in summaries if entry is not None]

if summaries:
if len(summaries) == 1:
return {"title": summaries[0][0],
"short_answer": summaries[0][1][0],
"summary": "\n".join(summaries[0][1]),
"img": summaries[0][2]}

final_ans = "\n".join([sentences[0] for title, sentences, _ in summaries[:3]])
final_sum = "\n\n".join([title + " - " + ".\n".join(sents)
for title, sents, img in summaries])
return {"title": query,
"short_answer": final_ans,
"summary": final_sum,
"img": summaries[0][2]}
summaries = [s for s in summaries if s is not None]
if not summaries:
return {}

return {}
reranked = []
shorts = []
for idx, (title, summary, img) in enumerate(summaries):
short = self.summarize(query, summary)
score = self.score_page(query, title, short, idx)
reranked.append((idx, score))
shorts.append(short)

reranked = sorted(reranked, key=lambda x: x[1], reverse=True)
selected = reranked[0][0]

return {
"title": summaries[selected][0],
"short_answer": shorts[selected],
"summary": summaries[selected][1],
"img": summaries[selected][2],
}

def get_spoken_answer(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None):
data = self.get_data(query, lang=lang, units=units)
units: Optional[str] = None,
skip_disambiguation: bool = False):
data = self.get_data(query, lang=lang, units=units,
skip_disambiguation=skip_disambiguation)
return data.get("short_answer", "")

def get_image(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None):
data = self.get_data(query, lang=lang, units=units)
units: Optional[str] = None,
skip_disambiguation: bool = True):
data = self.get_data(query, lang=lang, units=units,
skip_disambiguation=skip_disambiguation)
return data.get("img", "")

def get_expanded_answer(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None):
units: Optional[str] = None,
skip_disambiguation: bool = False):
"""
return a list of ordered steps to expand the answer, eg, "tell me more"
{
Expand All @@ -169,7 +268,8 @@ def get_expanded_answer(self, query: str,
"img": "optional/path/or/url
}
"""
data = self.get_data(query, lang=lang, units=units)
data = self.get_data(query, lang=lang, units=units,
skip_disambiguation=skip_disambiguation)
ans = flatten_list([sentence_tokenize(s) for s in data["summary"].split("\n")])
steps = [{
"title": data.get("title", query).title(),
Expand All @@ -193,7 +293,8 @@ def register_kw_xtract(self):

lang2 = closest_supported_match(lang, supported, 10)
if not lang2:
LOG.warning(f"'{self.root_dir}/locale/{lang}' directory not found! wikipedia will be disabled for '{lang}'")
LOG.warning(
f"'{self.root_dir}/locale/{lang}' directory not found! wikipedia will be disabled for '{lang}'")
continue

filename = f"{self.root_dir}/locale/{lang2}/query.intent"
Expand Down Expand Up @@ -368,6 +469,7 @@ def stop_session(self, sess):


if __name__ == "__main__":
LOG.set_level("ERROR")
from ovos_utils.fakebus import FakeBus

s = WikipediaSkill(bus=FakeBus(), skill_id="wiki.skill")
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ ovos-utils>=0.0.38,<1.0.0
ovos_workshop>=3.3.2,<4.0.0
ovos-plugin-manager>=0.0.26,<1.0.0
ovos-bus-client>=1.0.1
ovos-solver-bm25-plugin
2 changes: 1 addition & 1 deletion test/unittests/test_continuous_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from unittest import skip
from unittest.mock import Mock

from ovos_utils.messagebus import FakeBus, Message
from ovos_utils.fakebus import FakeBus, Message
from skill_ovos_wikipedia import WikipediaSkill


Expand Down
2 changes: 1 addition & 1 deletion test/unittests/test_lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from unittest import skip
from unittest.mock import Mock

from ovos_utils.messagebus import FakeBus, Message
from ovos_utils.fakebus import FakeBus, Message
from skill_ovos_wikipedia import WikipediaSkill


Expand Down
2 changes: 1 addition & 1 deletion test/unittests/test_skill.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import unittest

import requests
from ovos_utils.messagebus import FakeBus
from ovos_utils.fakebus import FakeBus
from skill_ovos_wikipedia import WikipediaSkill
from ovos_workshop.skills.common_query_skill import CommonQuerySkill

Expand Down
2 changes: 1 addition & 1 deletion test/unittests/test_skill_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from ovos_workshop.skill_launcher import PluginSkillLoader, SkillLoader
from ovos_plugin_manager.skills import find_skill_plugins
from ovos_utils.messagebus import FakeBus
from ovos_utils.fakebus import FakeBus
from skill_ovos_wikipedia import WikipediaSkill


Expand Down

0 comments on commit 29eb94e

Please sign in to comment.