From 0b8ea1e29875f8ff463a93007f8c3fa10da26533 Mon Sep 17 00:00:00 2001 From: miro Date: Wed, 16 Oct 2024 09:49:08 +0100 Subject: [PATCH] feat:pipeline plugin factory --- ovos_plugin_manager/pipeline.py | 118 ++++++++++++++++++-------------- 1 file changed, 68 insertions(+), 50 deletions(-) diff --git a/ovos_plugin_manager/pipeline.py b/ovos_plugin_manager/pipeline.py index f201428..ec56555 100644 --- a/ovos_plugin_manager/pipeline.py +++ b/ovos_plugin_manager/pipeline.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Callable, Union +from typing import List, Optional, Tuple, Callable, Union, Dict, Type from ovos_bus_client.client import MessageBusClient from ovos_config import Configuration @@ -9,7 +9,7 @@ from ovos_plugin_manager.utils import PluginTypes -def find_pipeline_plugins() -> dict: +def find_pipeline_plugins() -> Dict[str, Type[PipelinePlugin]]: """ Find all installed plugins @return: dict plugin names to entrypoints @@ -18,21 +18,54 @@ def find_pipeline_plugins() -> dict: return find_plugins(PluginTypes.PIPELINE) -def load_pipeline_plugin(module_name: str) -> type(PipelinePlugin): +def load_pipeline_plugin(module_name: str) -> Type[PipelinePlugin]: """ - Get an uninstantiated class for the requested module_name - @param module_name: Plugin entrypoint name to load - @return: Uninstantiated class + Load and return an uninstantiated class for the specified pipeline plugin. + + @param module_name: The name of the plugin to load. + @return: The uninstantiated plugin class. """ from ovos_plugin_manager.utils import load_plugin return load_plugin(module_name, PluginTypes.PIPELINE) class OVOSPipelineFactory: + """ + Factory class for managing and creating pipeline plugins. + """ + _CACHE = {} + _MAP = { + "converse": "ovos-converse-pipeline-plugin", + "common_qa": "ovos-common-query-pipeline-plugin", + "fallback_high": "ovos-fallback-pipeline-plugin-high", + "fallback_medium": "ovos-fallback-pipeline-plugin-medium", + "fallback_low": "ovos-fallback-pipeline-plugin-low", + "stop_high": "ovos-stop-pipeline-plugin-high", + "stop_medium": "ovos-stop-pipeline-plugin-medium", + "stop_low": "ovos-stop-pipeline-plugin-low", + "adapt_high": "ovos-adapt-pipeline-plugin-high", + "adapt_medium": "ovos-adapt-pipeline-plugin-medium", + "adapt_low": "ovos-adapt-pipeline-plugin-low", + "padacioso_high": "ovos-padacioso-pipeline-plugin-high", + "padacioso_medium": "ovos-padacioso-pipeline-plugin-medium", + "padacioso_low": "ovos-padacioso-pipeline-plugin-low", + "padatious_high": "ovos-padatious-pipeline-plugin-high", + "padatious_medium": "ovos-padatious-pipeline-plugin-medium", + "padatious_low": "ovos-padatious-pipeline-plugin-low", + "ocp_high": "ovos-ocp-pipeline-plugin-high", + "ocp_medium": "ovos-ocp-pipeline-plugin-medium", + "ocp_low": "ovos-ocp-pipeline-plugin-low", + "ocp_legacy": "ovos-ocp-pipeline-plugin-legacy" + } @staticmethod def get_installed_pipelines() -> List[str]: + """ + Get a list of installed pipelines. + + @return: A list of installed pipeline identifiers. + """ pipelines = [] for plug_id, clazz in find_pipeline_plugins().items(): if issubclass(clazz, ConfidenceMatcherPipeline): @@ -45,53 +78,27 @@ def get_installed_pipelines() -> List[str]: @staticmethod def get_pipeline_classes(pipeline: Optional[List[str]] = None) -> List[Tuple[str, type(PipelinePlugin)]]: - MAP = { - "converse": "ovos-converse-pipeline-plugin", - "common_qa": "ovos-common-query-pipeline-plugin", - "fallback_high": "ovos-fallback-pipeline-plugin-high", - "fallback_medium": "ovos-fallback-pipeline-plugin-medium", - "fallback_low": "ovos-fallback-pipeline-plugin-low", - "stop_high": "ovos-stop-pipeline-plugin-high", - "stop_medium": "ovos-stop-pipeline-plugin-medium", - "stop_low": "ovos-stop-pipeline-plugin-low", - "adapt_high": "ovos-adapt-pipeline-plugin-high", - "adapt_medium": "ovos-adapt-pipeline-plugin-medium", - "adapt_low": "ovos-adapt-pipeline-plugin-low", - "padacioso_high": "ovos-padacioso-pipeline-plugin-high", - "padacioso_medium": "ovos-padacioso-pipeline-plugin-medium", - "padacioso_low": "ovos-padacioso-pipeline-plugin-low", - "padatious_high": "ovos-padatious-pipeline-plugin-high", - "padatious_medium": "ovos-padatious-pipeline-plugin-medium", - "padatious_low": "ovos-padatious-pipeline-plugin-low", - "ocp_high": "ovos-ocp-pipeline-plugin-high", - "ocp_medium": "ovos-ocp-pipeline-plugin-medium", - "ocp_low": "ovos-ocp-pipeline-plugin-low", - "ocp_legacy": "ovos-ocp-pipeline-plugin-legacy" - } + """ + Get a list of pipeline plugin classes based on the pipeline configuration. + + @param pipeline: A list of pipeline plugin identifiers to load. + @return: A list of tuples containing the plugin identifier and the corresponding plugin class. + """ default_p = [ - "stop_high", - "converse", - "ocp_high", - "padatious_high", - "adapt_high", - "ocp_medium", - "fallback_high", - "stop_medium", - "adapt_medium", - "padatious_medium", - "adapt_low", - "common_qa", - "fallback_medium", - "fallback_low" + "stop_high", "converse", "ocp_high", "padatious_high", "adapt_high", + "ocp_medium", "fallback_high", "stop_medium", "adapt_medium", + "padatious_medium", "adapt_low", "common_qa", "fallback_medium", "fallback_low" ] - pipeline = pipeline or Configuration().get("intents", {}).get("pipeline", [MAP[p] for p in default_p]) + pipeline = pipeline or Configuration().get("intents", {}).get("pipeline", + [OVOSPipelineFactory._MAP[p] for p in default_p]) - deprecated = [p for p in pipeline if p in MAP] + deprecated = [p for p in pipeline if p in OVOSPipelineFactory._MAP] if deprecated: log_deprecation(f"pipeline names have changed, " - f"please migrate: '{deprecated}' to '{[MAP[p] for p in deprecated]}'", "1.0.0") + f"please migrate: '{deprecated}' to '{[OVOSPipelineFactory._MAP[p] for p in deprecated]}'", + "1.0.0") - valid_pipeline = [MAP.get(p, p) for p in pipeline] + valid_pipeline = [OVOSPipelineFactory._MAP.get(p, p) for p in pipeline] matchers = [] for plug_id, clazz in find_pipeline_plugins().items(): if issubclass(clazz, ConfidenceMatcherPipeline): @@ -100,7 +107,7 @@ def get_pipeline_classes(pipeline: Optional[List[str]] = None) -> List[Tuple[str if f"{plug_id}-medium" in valid_pipeline: matchers.append((f"{plug_id}-medium", clazz)) if f"{plug_id}-high" in valid_pipeline: - matchers.append((plug_id, clazz)) + matchers.append((f"{plug_id}-high", clazz)) else: matchers.append((plug_id, clazz)) @@ -110,7 +117,15 @@ def get_pipeline_classes(pipeline: Optional[List[str]] = None) -> List[Tuple[str def create(pipeline: Optional[List[str]] = None, use_cache: bool = True, bus: Optional[Union[MessageBusClient, FakeBus]] = None, skip_stage_matchers: bool = False) -> List[Tuple[str, Callable]]: - """Factory method to create pipeline matchers""" + """ + Factory method to create pipeline matchers. + + @param pipeline: A list of pipeline plugin identifiers to load. + @param use_cache: Whether to cache the created matchers for reuse. + @param bus: The message bus client to use for the pipelines. + @param skip_stage_matchers: Whether to skip the stage matchers (i.e., matchers with side effects). + @return: A list of tuples containing the pipeline identifier and the matcher callable. + """ matchers = [] for pipe_id, clazz in OVOSPipelineFactory.get_pipeline_classes(pipeline): if use_cache and pipe_id in OVOSPipelineFactory._CACHE: @@ -132,7 +147,10 @@ def create(pipeline: Optional[List[str]] = None, use_cache: bool = True, return matchers @staticmethod - def shutdown(): + def shutdown() -> None: + """ + Shutdown all cached pipeline plugins by calling their shutdown methods if available. + """ for pipe in OVOSPipelineFactory._CACHE.values(): if hasattr(pipe, "shutdown"): try: