Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/domain_engine #31

Merged
merged 12 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ovos_padatious/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from .intent_container import IntentContainer
from .domain_container import DomainIntentContainer
from .match_data import MatchData

__version__ = '0.4.8' # Also change in setup.py
175 changes: 175 additions & 0 deletions ovos_padatious/domain_container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from collections import defaultdict
from typing import Dict, List, Optional
from ovos_utils.log import LOG
from ovos_padatious.intent_container import IntentContainer
from ovos_padatious.match_data import MatchData


class DomainIntentContainer:
"""
A domain-aware intent recognition engine that organizes intents and entities
into specific domains, providing flexible and hierarchical intent matching.
"""

def __init__(self, cache_dir: Optional[str] = None, disable_padaos: bool = False):
"""
Initialize the DomainIntentEngine.

Attributes:
domain_engine (IntentContainer): A top-level intent container for cross-domain calculations.
domains (Dict[str, IntentContainer]): A mapping of domain names to their respective intent containers.
training_data (Dict[str, List[str]]): A mapping of domain names to their associated training samples.
"""
self.cache_dir = cache_dir
self.disable_padaos = disable_padaos
self.domain_engine = IntentContainer(cache_dir=cache_dir,
disable_padaos=disable_padaos)
self.domains: Dict[str, IntentContainer] = {}
self.training_data: Dict[str, List[str]] = defaultdict(list)
self.must_train = True

def remove_domain(self, domain_name: str):
"""
Remove a domain and its associated intents and training data.

Args:
domain_name (str): The name of the domain to remove.
"""
if domain_name in self.training_data:
self.training_data.pop(domain_name)
if domain_name in self.domains:
self.domains.pop(domain_name)
if domain_name in self.domain_engine.intent_names:
self.domain_engine.remove_intent(domain_name)

def add_domain_intent(self, domain_name: str, intent_name: str, intent_samples: List[str]):
"""
Register an intent within a specific domain.

Args:
domain_name (str): The name of the domain.
intent_name (str): The name of the intent to register.
intent_samples (List[str]): A list of sample sentences for the intent.
"""
if domain_name not in self.domains:
self.domains[domain_name] = IntentContainer(cache_dir=self.cache_dir,
disable_padaos=self.disable_padaos)
self.domains[domain_name].add_intent(intent_name, intent_samples)
self.training_data[domain_name] += intent_samples
self.must_train = True

def remove_domain_intent(self, domain_name: str, intent_name: str):
"""
Remove a specific intent from a domain.

Args:
domain_name (str): The name of the domain.
intent_name (str): The name of the intent to remove.
"""
if domain_name in self.domains:
self.domains[domain_name].remove_intent(intent_name)

def add_domain_entity(self, domain_name: str, entity_name: str, entity_samples: List[str]):
"""
Register an entity within a specific domain.

Args:
domain_name (str): The name of the domain.
entity_name (str): The name of the entity to register.
entity_samples (List[str]): A list of sample phrases for the entity.
"""
if domain_name not in self.domains:
self.domains[domain_name] = IntentContainer(cache_dir=self.cache_dir,
disable_padaos=self.disable_padaos)
self.domains[domain_name].add_entity(entity_name, entity_samples)

def remove_domain_entity(self, domain_name: str, entity_name: str):
"""
Remove a specific entity from a domain.

Args:
domain_name (str): The name of the domain.
entity_name (str): The name of the entity to remove.
"""
if domain_name in self.domains:
self.domains[domain_name].remove_entity(entity_name)

def calc_domains(self, query: str) -> List[MatchData]:
"""
Calculate the matching domains for a query.

Args:
query (str): The input query.

Returns:
List[MatchData]: A list of MatchData objects representing matching domains.
"""
if self.must_train:
self.train()

return self.domain_engine.calc_intents(query)

def calc_domain(self, query: str) -> MatchData:
"""
Calculate the best matching domain for a query.

Args:
query (str): The input query.

Returns:
MatchData: The best matching domain.
"""
if self.must_train:
self.train()
return self.domain_engine.calc_intent(query)

def calc_intent(self, query: str, domain: Optional[str] = None) -> MatchData:
"""
Calculate the best matching intent for a query within a specific domain.

Args:
query (str): The input query.
domain (Optional[str]): The domain to limit the search to. Defaults to None.

Returns:
MatchData: The best matching intent.
"""
if self.must_train:
self.train()
domain: str = domain or self.domain_engine.calc_intent(query).name
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
if domain in self.domains:
return self.domains[domain].calc_intent(query)
return MatchData(name=None, sent=query, matches=None, conf=0.0)

def calc_intents(self, query: str, domain: Optional[str] = None, top_k_domains: int = 2) -> List[MatchData]:
"""
Calculate matching intents for a query across domains or within a specific domain.

Args:
query (str): The input query.
domain (Optional[str]): The specific domain to search in. If None, searches across top-k domains.
top_k_domains (int): The number of top domains to consider. Defaults to 2.

Returns:
List[MatchData]: A list of MatchData objects representing matching intents, sorted by confidence.
"""
if self.must_train:
self.train()
if domain:
return self.domains[domain].calc_intents(query)
matches = []
domains = self.calc_domains(query)[:top_k_domains]
for domain in domains:
if domain.name in self.domains:
matches += self.domains[domain.name].calc_intents(query)
return sorted(matches, reverse=True, key=lambda k: k.conf)

def train(self):
for domain, samples in self.training_data.items():
LOG.debug(f"Training domain: {domain}")
self.domain_engine.add_intent(domain, samples)
self.domain_engine.train()
for domain in self.domains:
LOG.debug(f"Training domain sub-intents: {domain}")
self.domains[domain].train()
self.must_train = False
9 changes: 8 additions & 1 deletion ovos_padatious/intent_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from functools import wraps
from typing import List, Dict, Any, Optional

from ovos_config.meta import get_xdg_base
from ovos_utils.log import LOG
from ovos_utils.xdg_utils import xdg_data_home

from ovos_padatious import padaos
from ovos_padatious.entity import Entity
Expand Down Expand Up @@ -54,7 +56,8 @@ class IntentContainer:
cache_dir (str): Directory for caching the neural network models and intent/entity files.
"""

def __init__(self, cache_dir: str, disable_padaos: bool = False) -> None:
def __init__(self, cache_dir: Optional[str] = None, disable_padaos: bool = False) -> None:
cache_dir = cache_dir or f"{xdg_data_home()}/{get_xdg_base()}/intent_cache"
os.makedirs(cache_dir, exist_ok=True)
self.cache_dir: str = cache_dir
self.must_train: bool = False
Expand All @@ -68,6 +71,10 @@ def __init__(self, cache_dir: str, disable_padaos: bool = False) -> None:
self.train_thread: Optional[Any] = None # deprecated
self.serialized_args: List[Dict[str, Any]] = [] # Serialized calls for training intents/entities

@property
def intent_names(self):
return self.intents.intent_names

def clear(self) -> None:
"""
Clears the current intent and entity managers and resets the container.
Expand Down
5 changes: 4 additions & 1 deletion ovos_padatious/intent_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ def __init__(self, cache: str, debug: bool = False):
super().__init__(Intent, cache)
self.debug = debug

@property
def intent_names(self):
return [i.name for i in self.objects + self.objects_to_train]

def calc_intents(self, query: str, entity_manager) -> List[MatchData]:
"""
Calculate matches for the given query against all registered intents.
Expand All @@ -44,7 +48,6 @@ def calc_intents(self, query: str, entity_manager) -> List[MatchData]:
List[MatchData]: A list of matches sorted by confidence.
"""
sent = tokenize(query)
matches = []

def match_intent(intent):
start_time = time.monotonic()
Expand Down
Loading
Loading