diff --git a/ovos_plugin_manager/pipeline.py b/ovos_plugin_manager/pipeline.py index 56bcc01c..ec56555a 100644 --- a/ovos_plugin_manager/pipeline.py +++ b/ovos_plugin_manager/pipeline.py @@ -1,8 +1,15 @@ -from ovos_plugin_manager.templates.pipeline import PipelinePlugin +from typing import List, Optional, Tuple, Callable, Union, Dict, Type + +from ovos_bus_client.client import MessageBusClient +from ovos_config import Configuration +from ovos_utils.fakebus import FakeBus +from ovos_utils.log import log_deprecation + +from ovos_plugin_manager.templates.pipeline import ConfidenceMatcherPipeline, PipelineStageMatcher, PipelinePlugin from ovos_plugin_manager.utils import PluginTypes -def find_pipeline_plugins() -> dict: +def find_pipeline_plugins() -> Dict[str, Type[PipelinePlugin]]: """ Find all installed plugins @return: dict plugin names to entrypoints @@ -11,11 +18,142 @@ def find_pipeline_plugins() -> dict: return find_plugins(PluginTypes.PIPELINE) -def load_pipeline_plugin(module_name: str) -> type(PipelinePlugin): +def load_pipeline_plugin(module_name: str) -> Type[PipelinePlugin]: """ - Get an uninstantiated class for the requested module_name - @param module_name: Plugin entrypoint name to load - @return: Uninstantiated class + Load and return an uninstantiated class for the specified pipeline plugin. + + @param module_name: The name of the plugin to load. + @return: The uninstantiated plugin class. """ from ovos_plugin_manager.utils import load_plugin return load_plugin(module_name, PluginTypes.PIPELINE) + + +class OVOSPipelineFactory: + """ + Factory class for managing and creating pipeline plugins. + """ + + _CACHE = {} + _MAP = { + "converse": "ovos-converse-pipeline-plugin", + "common_qa": "ovos-common-query-pipeline-plugin", + "fallback_high": "ovos-fallback-pipeline-plugin-high", + "fallback_medium": "ovos-fallback-pipeline-plugin-medium", + "fallback_low": "ovos-fallback-pipeline-plugin-low", + "stop_high": "ovos-stop-pipeline-plugin-high", + "stop_medium": "ovos-stop-pipeline-plugin-medium", + "stop_low": "ovos-stop-pipeline-plugin-low", + "adapt_high": "ovos-adapt-pipeline-plugin-high", + "adapt_medium": "ovos-adapt-pipeline-plugin-medium", + "adapt_low": "ovos-adapt-pipeline-plugin-low", + "padacioso_high": "ovos-padacioso-pipeline-plugin-high", + "padacioso_medium": "ovos-padacioso-pipeline-plugin-medium", + "padacioso_low": "ovos-padacioso-pipeline-plugin-low", + "padatious_high": "ovos-padatious-pipeline-plugin-high", + "padatious_medium": "ovos-padatious-pipeline-plugin-medium", + "padatious_low": "ovos-padatious-pipeline-plugin-low", + "ocp_high": "ovos-ocp-pipeline-plugin-high", + "ocp_medium": "ovos-ocp-pipeline-plugin-medium", + "ocp_low": "ovos-ocp-pipeline-plugin-low", + "ocp_legacy": "ovos-ocp-pipeline-plugin-legacy" + } + + @staticmethod + def get_installed_pipelines() -> List[str]: + """ + Get a list of installed pipelines. + + @return: A list of installed pipeline identifiers. + """ + pipelines = [] + for plug_id, clazz in find_pipeline_plugins().items(): + if issubclass(clazz, ConfidenceMatcherPipeline): + pipelines.append(f"{plug_id}-low") + pipelines.append(f"{plug_id}-medium") + pipelines.append(f"{plug_id}-high") + else: + pipelines.append(plug_id) + return pipelines + + @staticmethod + def get_pipeline_classes(pipeline: Optional[List[str]] = None) -> List[Tuple[str, type(PipelinePlugin)]]: + """ + Get a list of pipeline plugin classes based on the pipeline configuration. + + @param pipeline: A list of pipeline plugin identifiers to load. + @return: A list of tuples containing the plugin identifier and the corresponding plugin class. + """ + default_p = [ + "stop_high", "converse", "ocp_high", "padatious_high", "adapt_high", + "ocp_medium", "fallback_high", "stop_medium", "adapt_medium", + "padatious_medium", "adapt_low", "common_qa", "fallback_medium", "fallback_low" + ] + pipeline = pipeline or Configuration().get("intents", {}).get("pipeline", + [OVOSPipelineFactory._MAP[p] for p in default_p]) + + deprecated = [p for p in pipeline if p in OVOSPipelineFactory._MAP] + if deprecated: + log_deprecation(f"pipeline names have changed, " + f"please migrate: '{deprecated}' to '{[OVOSPipelineFactory._MAP[p] for p in deprecated]}'", + "1.0.0") + + valid_pipeline = [OVOSPipelineFactory._MAP.get(p, p) for p in pipeline] + matchers = [] + for plug_id, clazz in find_pipeline_plugins().items(): + if issubclass(clazz, ConfidenceMatcherPipeline): + if f"{plug_id}-low" in valid_pipeline: + matchers.append((f"{plug_id}-low", clazz)) + if f"{plug_id}-medium" in valid_pipeline: + matchers.append((f"{plug_id}-medium", clazz)) + if f"{plug_id}-high" in valid_pipeline: + matchers.append((f"{plug_id}-high", clazz)) + else: + matchers.append((plug_id, clazz)) + + return matchers + + @staticmethod + def create(pipeline: Optional[List[str]] = None, use_cache: bool = True, + bus: Optional[Union[MessageBusClient, FakeBus]] = None, + skip_stage_matchers: bool = False) -> List[Tuple[str, Callable]]: + """ + Factory method to create pipeline matchers. + + @param pipeline: A list of pipeline plugin identifiers to load. + @param use_cache: Whether to cache the created matchers for reuse. + @param bus: The message bus client to use for the pipelines. + @param skip_stage_matchers: Whether to skip the stage matchers (i.e., matchers with side effects). + @return: A list of tuples containing the pipeline identifier and the matcher callable. + """ + matchers = [] + for pipe_id, clazz in OVOSPipelineFactory.get_pipeline_classes(pipeline): + if use_cache and pipe_id in OVOSPipelineFactory._CACHE: + m = OVOSPipelineFactory._CACHE[pipe_id] + else: + config = Configuration().get("intents", {}).get(pipe_id) + m = clazz(bus, config) + if use_cache: + OVOSPipelineFactory._CACHE[pipe_id] = m + if isinstance(m, ConfidenceMatcherPipeline): + if pipe_id.endswith("-high"): + matchers.append((pipe_id, m.match_high)) + elif pipe_id.endswith("-medium"): + matchers.append((pipe_id, m.match_medium)) + elif pipe_id.endswith("-low"): + matchers.append((pipe_id, m.match_low)) + elif isinstance(m, PipelineStageMatcher) and not skip_stage_matchers: + matchers.append((pipe_id, m.match)) + return matchers + + @staticmethod + def shutdown() -> None: + """ + Shutdown all cached pipeline plugins by calling their shutdown methods if available. + """ + for pipe in OVOSPipelineFactory._CACHE.values(): + if hasattr(pipe, "shutdown"): + try: + pipe.shutdown() + except: + continue diff --git a/ovos_plugin_manager/templates/pipeline.py b/ovos_plugin_manager/templates/pipeline.py index ce85dda7..4afc7e0d 100644 --- a/ovos_plugin_manager/templates/pipeline.py +++ b/ovos_plugin_manager/templates/pipeline.py @@ -1,18 +1,238 @@ +import abc from collections import namedtuple -from typing import Optional, Dict +from dataclasses import dataclass +from typing import Optional, Dict, List, Union -# Intent match response tuple, ovos-core expects PipelinePlugin to return this data structure +from ovos_bus_client.client import MessageBusClient +from ovos_bus_client.message import Message +from ovos_utils.fakebus import FakeBus + +# LEGACY: Intent match response tuple, ovos-core~=0.2 expects PipelinePlugin to return this data structure # intent_service: Name of the service that matched the intent # intent_type: intent name (used to call intent handler over the message bus) # intent_data: data provided by the intent match # skill_id: the skill this handler belongs to +# TODO - deprecated IntentMatch = namedtuple('IntentMatch', ['intent_service', 'intent_type', 'intent_data', 'skill_id', 'utterance'] ) +@dataclass +class IntentHandlerMatch: + """ + Represents an intent handler match result, expected by ovos-core plugins. + + Attributes: + match_type (str): Name of the service that matched the intent. + match_data (Optional[Dict]): Additional data provided by the intent match. + skill_id (Optional[str]): The skill this handler belongs to. + utterance (Optional[str]): The original utterance triggering the intent. + """ + match_type: str + match_data: Optional[Dict] = None + skill_id: Optional[str] = None + utterance: Optional[str] = None + + +@dataclass +class PipelineMatch: + """ + Represents a match in a pipeline that does not trigger an intent message directly. + + Attributes: + match_type (bool): Indicates if the utterance was matched (compatibility only). + handled (bool): Whether the match has already handled the utterance. + match_data (Optional[Dict]): Data provided by the intent match. + skill_id (Optional[str]): The skill this handler belongs to. + utterance (Optional[str]): The original utterance triggering the match. + """ + match_type: bool = True + handled: bool = True + match_data: Optional[Dict] = None + skill_id: Optional[str] = None + utterance: Optional[str] = None + + class PipelinePlugin: - """This class is a placeholder, this API will be defined in ovos-core release 0.1.0""" + """ + Base class for intent matching pipeline plugins. Mainly useful for typing + + Attributes: + config (Dict): Configuration for the plugin. + """ + def __init__(self, config: Optional[Dict] = None): self.config = config or {} + + +class ConfidenceMatcherPipeline(PipelinePlugin): + """ + Base class for plugins that match utterances with confidence levels, + but do not directly trigger actions. + + Example plugins: adapt, padatious. + + Attributes: + bus (Union[MessageBusClient, FakeBus]): The message bus client for communication. + """ + + def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None, + config: Optional[Dict] = None): + self.bus = bus or FakeBus() + super().__init__(config=config) + + @abc.abstractmethod + def match_high(self, utterances: List[str], lang: str, message: Message) -> Optional[IntentMatch]: + """ + Match an utterance with high confidence. + + Args: + utterances (List[str]): List of utterances to match. + lang (str): The language of the utterances. + message (Message): The message containing the utterance. + + Returns: + Optional[IntentMatch]: The match result or None if no match is found. + """ + pass + + @abc.abstractmethod + def match_medium(self, utterances: List[str], lang: str, message: Message) -> Optional[IntentMatch]: + """ + Match an utterance with medium confidence. + + Args: + utterances (List[str]): List of utterances to match. + lang (str): The language of the utterances. + message (Message): The message containing the utterance. + + Returns: + Optional[IntentMatch]: The match result or None if no match is found. + """ + pass + + @abc.abstractmethod + def match_low(self, utterances: List[str], lang: str, message: Message) -> Optional[IntentMatch]: + """ + Match an utterance with low confidence. + + Args: + utterances (List[str]): List of utterances to match. + lang (str): The language of the utterances. + message (Message): The message containing the utterance. + + Returns: + Optional[IntentMatch]: The match result or None if no match is found. + """ + pass + + +class PipelineStageMatcher(PipelinePlugin): + """ + Base class for plugins that consume an utterance during matching, + aborting subsequent pipeline stages if a match is found. + + WARNING: has side effects when match is used + + these plugins will consume an utterance during the match process, + it is not known if this component will match without going through the match process + + Example plugins: converse, common_query. + + Attributes: + bus (Union[MessageBusClient, FakeBus]): The message bus client for communication. + """ + + def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None, + config: Optional[Dict] = None): + self.bus = bus or FakeBus() + super().__init__(config=config) + + @abc.abstractmethod + def match(self, utterances: List[str], lang: str, message: Message) -> Optional[PipelineMatch]: + """ + Match an utterance, potentially aborting further stages in the pipeline. + + Args: + utterances (List[str]): List of utterances to match. + lang (str): The language of the utterances. + message (Message): The message containing the utterance. + + Returns: + Optional[PipelineMatch]: The match result or None if no match is found. + """ + pass + + +class PipelineStageConfidenceMatcher(PipelineStageMatcher, ConfidenceMatcherPipeline): + """ + A specialized matcher that consumes utterances during the matching process + and supports confidence levels. It aborts further pipeline stages if a match is found. + + Example plugins: fallback, stop. + + Attributes: + bus (Union[MessageBusClient, FakeBus]): The message bus client for communication. + """ + + def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None, + config: Optional[Dict] = None): + super().__init__(bus=bus, config=config) + + def match(self, utterances: List[str], lang: str, message: Message) -> Optional[PipelineMatch]: + """ + Match an utterance using high confidence, with no specific match level defined. + + Args: + utterances (List[str]): List of utterances to match. + lang (str): The language of the utterances. + message (Message): The message containing the utterance. + + Returns: + Optional[PipelineMatch]: The match result or None if no match is found. + """ + return self.match_high(utterances, lang, message) + + @abc.abstractmethod + def match_high(self, utterances: List[str], lang: str, message: Message) -> Optional[PipelineMatch]: + """ + Match an utterance with high confidence. + + Args: + utterances (List[str]): List of utterances to match. + lang (str): The language of the utterances. + message (Message): The message containing the utterance. + + Returns: + Optional[IntentMatch]: The match result or None if no match is found. + """ + + @abc.abstractmethod + def match_medium(self, utterances: List[str], lang: str, message: Message) -> Optional[PipelineMatch]: + """ + Match an utterance with medium confidence. + + Args: + utterances (List[str]): List of utterances to match. + lang (str): The language of the utterances. + message (Message): The message containing the utterance. + + Returns: + Optional[IntentMatch]: The match result or None if no match is found. + """ + + @abc.abstractmethod + def match_low(self, utterances: List[str], lang: str, message: Message) -> Optional[PipelineMatch]: + """ + Match an utterance with low confidence. + + Args: + utterances (List[str]): List of utterances to match. + lang (str): The language of the utterances. + message (Message): The message containing the utterance. + + Returns: + Optional[IntentMatch]: The match result or None if no match is found. + """