diff --git a/ovos_plugin_manager/utils/config.py b/ovos_plugin_manager/utils/config.py index 25416537..3746f455 100644 --- a/ovos_plugin_manager/utils/config.py +++ b/ovos_plugin_manager/utils/config.py @@ -19,8 +19,9 @@ def get_plugin_config(config: Optional[dict] = None, section: str = None, module = module or config.get('module') if module: module_config = config.get(module) or dict() - module_config.setdefault('lang', lang) module_config.setdefault('module', module) + if section not in ["hotwords", "VAD", "listener"]: + module_config.setdefault('lang', lang) return module_config if section not in ["hotwords", "VAD", "listener"]: config.setdefault('lang', lang) diff --git a/ovos_plugin_manager/vad.py b/ovos_plugin_manager/vad.py index cfed63db..d9cdf49f 100644 --- a/ovos_plugin_manager/vad.py +++ b/ovos_plugin_manager/vad.py @@ -44,12 +44,14 @@ def get_class(config=None): The configuration file ``mycroft.conf`` contains a ``vad`` section with the name of a VAD module to be read by this method. - "vad": { + "VAD": { "module": } """ config = get_vad_config(config) - vad_module = config.get("module", "dummy") + vad_module = config.get("module") + if not vad_module: + raise ValueError(f"VAD Plugin not configured in: {config}") if vad_module == "dummy": return VADEngine if vad_module in OVOSVADFactory.MAPPINGS: @@ -60,29 +62,36 @@ def get_class(config=None): def create(config=None): """Factory method to create a VAD engine based on configuration. - The configuration file ``mycroft.conf`` contains a ``vad`` section with + The configuration file ``mycroft.conf`` contains a ``VAD`` section with the name of a VAD module to be read by this method. - "vad": { + "VAD": { "module": } """ - config = config or get_vad_config() - plugin = config.get("module") or "dummy" - plugin_config = config.get(plugin) or {} + vad_config = get_vad_config(config) + plugin = vad_config.get("module") + if not plugin: + raise ValueError(f"VAD Plugin not configured in: {vad_config}") try: - clazz = OVOSVADFactory.get_class(config) - return clazz(plugin_config) + clazz = OVOSVADFactory.get_class(vad_config) + return clazz(vad_config) except Exception: LOG.exception(f'VAD plugin {plugin} could not be loaded!') raise def get_vad_config(config=None): + """ + Get the VAD configuration, including `module` and module-specific config + @param config: Configuration dict to parse (default Configuration()) + @return: dict containing `module` and module-specific configuration + """ from ovos_plugin_manager.utils.config import get_plugin_config config = config or Configuration() if "listener" in config and "VAD" not in config: - return get_plugin_config(config, "listener") - else: - return get_plugin_config(config, "VAD") + config = get_plugin_config(config, "listener") + if "VAD" in config: + config = get_plugin_config(config, "VAD") + return config diff --git a/test/unittests/test_vad.py b/test/unittests/test_vad.py new file mode 100644 index 00000000..fcb90531 --- /dev/null +++ b/test/unittests/test_vad.py @@ -0,0 +1,84 @@ +import unittest +from unittest.mock import patch, Mock +from copy import copy, deepcopy +from ovos_plugin_manager import PluginTypes + +_TEST_CONFIG = { + "lang": "en-us", + "listener": { + "VAD": { + "module": "dummy", + "dummy": { + "vad_param": True + }, + "ovos-vad-plugin-webrtcvad": { + "vad_mode": 2 + } + } + } +} + + +class TestVADFactory(unittest.TestCase): + def test_create(self): + from ovos_plugin_manager.vad import OVOSVADFactory + real_get_class = OVOSVADFactory.get_class + mock_class = Mock() + + mock_get_class = Mock(return_value=mock_class) + OVOSVADFactory.get_class = mock_get_class + + OVOSVADFactory.create(config=_TEST_CONFIG) + mock_get_class.assert_called_once_with( + {**_TEST_CONFIG['listener']['VAD']['dummy'], **{"module": "dummy"}}) + mock_class.assert_called_once_with( + _TEST_CONFIG['listener']["VAD"]['dummy']) + + # Test invalid config + with self.assertRaises(ValueError): + OVOSVADFactory.create({'VAD': {'value': None}}) + + OVOSVADFactory.get_class = real_get_class + + @patch("ovos_plugin_manager.vad.load_plugin") + def test_get_class(self, load_plugin): + mock = Mock() + load_plugin.return_value = mock + from ovos_plugin_manager.vad import OVOSVADFactory + from ovos_plugin_manager.templates.vad import VADEngine + + # Test invalid config + with self.assertRaises(ValueError): + OVOSVADFactory.get_class({'module': None}) + + # Test dummy module + module = OVOSVADFactory.get_class(_TEST_CONFIG) + load_plugin.assert_not_called() + self.assertEqual(VADEngine, module) + + # Test valid module + config = deepcopy(_TEST_CONFIG) + config['listener']['VAD']['module'] = 'ovos-vad-plugin-webrtcvad' + module = OVOSVADFactory.get_class(config) + load_plugin.assert_called_once_with('ovos-vad-plugin-webrtcvad', + PluginTypes.VAD) + self.assertEqual(module, mock) + + def test_get_vad_config(self): + from ovos_plugin_manager.vad import get_vad_config + config = copy(_TEST_CONFIG) + dummy_config = get_vad_config(config) + self.assertEqual(dummy_config, + {**_TEST_CONFIG['listener']['VAD']['dummy'], + **{'module': 'dummy'}}) + config = copy(_TEST_CONFIG) + config['listener']['VAD']['module'] = 'ovos-vad-plugin-webrtcvad' + webrtc_config = get_vad_config(config) + self.assertEqual(webrtc_config, + {**_TEST_CONFIG['listener']['VAD'] + ['ovos-vad-plugin-webrtcvad'], + **{'module': 'ovos-vad-plugin-webrtcvad'}}) + config = copy(_TEST_CONFIG) + config['VAD'] = {'module': 'fake'} + fake_config = get_vad_config(config) + self.assertEqual(fake_config, {'module': 'fake'})