From 1260146b7cbae13086e2dd0640ba300cfacc5817 Mon Sep 17 00:00:00 2001 From: nsheff Date: Mon, 6 Nov 2023 21:09:39 -0500 Subject: [PATCH] Disallow overriding targets. Fix #22 --- markmeld/__init__.py | 4 ++-- markmeld/cli.py | 4 ++-- markmeld/utilities.py | 44 +++++++++++++++++++++++++++--------------- tests/test_markmeld.py | 16 +++++++-------- 4 files changed, 40 insertions(+), 28 deletions(-) diff --git a/markmeld/__init__.py b/markmeld/__init__.py index 6ed6fe1..fab8789 100644 --- a/markmeld/__init__.py +++ b/markmeld/__init__.py @@ -1,8 +1,8 @@ from .melder import MarkdownMelder from .cli import main -from .utilities import load_config_file +from .utilities import load_config_file, load_config_wrapper -__all__ = ["MarkdownMelder", "load_config_file"] +__all__ = ["MarkdownMelder", "load_config_file", "load_config_wrapper"] if __name__ == "__main__": try: diff --git a/markmeld/cli.py b/markmeld/cli.py index b50f040..832202d 100644 --- a/markmeld/cli.py +++ b/markmeld/cli.py @@ -8,7 +8,7 @@ from .exceptions import * from .melder import MarkdownMelder -from .utilities import load_config_file, get_file_open_cmd +from .utilities import load_config_wrapper, get_file_open_cmd from ._version import __version__ tpl = """imports: null @@ -148,7 +148,7 @@ def main(test_args=None): _LOGGER.error(msg) raise ConfigError(msg) - cfg = load_config_file(args.config, None, args.autocomplete) + cfg = load_config_wrapper(args.config, None, args.autocomplete) if args.autocomplete: if "targets" not in cfg: diff --git a/markmeld/utilities.py b/markmeld/utilities.py index cb75e76..6677ff2 100644 --- a/markmeld/utilities.py +++ b/markmeld/utilities.py @@ -13,7 +13,6 @@ _LOGGER = getLogger(PKG_NAME) - # define some useful functions def recursive_get(dat, indices): """ @@ -75,7 +74,15 @@ def format_command(tgt): # 2. the location of where it should be executed (workpath) # These are not the same thing. -def load_config_file(filepath, workpath=None, autocomplete=True): +def load_config_wrapper(cfg_path, workpath=None, autocomplete=True): + """ + Wrapper function that maintains a list of imported files, to prevent duplicate imports. + """ + imported_list = {} + return load_config_file(cfg_path, workpath, autocomplete, imported_list) + + +def load_config_file(filepath, workpath=None, autocomplete=True, imported_list={}): """ Loads a configuration file. @@ -83,27 +90,34 @@ def load_config_file(filepath, workpath=None, autocomplete=True): @param str workpath The working path that the target's relative paths are relative to @return dict Loaded yaml data object. """ - + _LOGGER.debug(f"Loading config file: {filepath}") + _LOGGER.debug(f"Imported list: {imported_list}") + if imported_list.get(filepath): + _LOGGER.debug(f"Already imported: {filepath}") + return {} try: with open(filepath, "r") as f: cfg_data = f.read() return load_config_data( - cfg_data, os.path.abspath(filepath), workpath, autocomplete + cfg_data, os.path.abspath(filepath), workpath, autocomplete, imported_list ) + except FileNotFoundError as e: + _LOGGER.error(f"Couldn't load config file: {filepath} because: {repr(e)}") + return {} # Allow continuing if file not found except Exception as e: _LOGGER.error(f"Couldn't load config file: {filepath} because: {repr(e)}") - return {} + raise e # Fail on other errors def make_abspath(relpath, filepath, root=None): if root: return os.path.join(root, relpath) - return os.path.join(os.path.dirname(filepath), relpath) + return os.path.abspath(os.path.join(os.path.dirname(filepath), relpath)) -def load_config_data(cfg_data, filepath=None, workpath=None, autocomplete=True): +def load_config_data(cfg_data, filepath=None, workpath=None, autocomplete=True, imported_list={}): """ - Recursive loader that parses a yaml string, handles imports, and runs target factories. + Recursive loader that parses a yaml string, handles imports, and runs target factories to create targets. """ higher_cfg = yaml.load(cfg_data, Loader=yaml.SafeLoader) higher_cfg["_cfg_file_path"] = filepath @@ -115,7 +129,6 @@ def load_config_data(cfg_data, filepath=None, workpath=None, autocomplete=True): if "targets" in higher_cfg: for tgt in higher_cfg["targets"]: higher_cfg["targets"][tgt]["_defpath"] = filepath - _LOGGER.debug(tgt, higher_cfg["targets"][tgt]) if workpath: higher_cfg["targets"][tgt]["_workpath"] = workpath else: @@ -125,9 +138,7 @@ def load_config_data(cfg_data, filepath=None, workpath=None, autocomplete=True): if "imports" in higher_cfg and higher_cfg["imports"]: _LOGGER.debug("Found imports") for import_file in higher_cfg["imports"]: - import_file_abspath = os.path.relpath( - make_abspath(expandpath(import_file), expandpath(filepath)) - ) + import_file_abspath = make_abspath(expandpath(import_file), expandpath(filepath)) if not autocomplete: _LOGGER.info(f"Specified config file to import: {import_file_abspath}") deep_update( @@ -135,13 +146,12 @@ def load_config_data(cfg_data, filepath=None, workpath=None, autocomplete=True): load_config_file(import_file_abspath, expandpath(filepath)), warn_override=not autocomplete, ) + imported_list[import_file_abspath] = True if "imports_relative" in higher_cfg and higher_cfg["imports_relative"]: _LOGGER.debug("Found relative imports") for import_file in higher_cfg["imports_relative"]: - import_file_abspath = os.path.relpath( - make_abspath(expandpath(import_file), expandpath(filepath)) - ) + import_file_abspath = make_abspath(expandpath(import_file), expandpath(filepath)) if not autocomplete: _LOGGER.info( f"Specified relative config file to import (relative): {import_file}" @@ -151,6 +161,7 @@ def load_config_data(cfg_data, filepath=None, workpath=None, autocomplete=True): load_config_file(expandpath(import_file_abspath)), warn_override=not autocomplete, ) + imported_list[import_file_abspath] = True deep_update(lower_cfg, higher_cfg, warn_override=not autocomplete) @@ -183,9 +194,10 @@ def warn_overriding_target(old, new): _LOGGER.error(f"Overriding target: {tgt}") _LOGGER.error("Originally defined in: ".rjust(27, " ") + f"{old['targets'][tgt]['_defpath']}") _LOGGER.error("Redefined in: ".rjust(27, " ") + f"{new['targets'][tgt]['_defpath']}") + raise Exception("Same target name is defined in imported file. Overriding targets is not allowed.") -def deep_update(old, new, warn_override=False): +def deep_update(old, new, warn_override=True): """ Like built-in dict update, but recursive. """ diff --git a/tests/test_markmeld.py b/tests/test_markmeld.py index 6886f60..c2e9d9f 100644 --- a/tests/test_markmeld.py +++ b/tests/test_markmeld.py @@ -140,14 +140,14 @@ def test_inherited_data_propogates_to_target(): def test_import(): - cfg = markmeld.load_config_file("tests/test_data/_markmeld_import.yaml") + cfg = markmeld.load_config_wrapper("tests/test_data/_markmeld_import.yaml") mm = markmeld.MarkdownMelder(cfg) res = mm.build_target("imported_target", print_only=True) print(res.melded_output) assert "rVEeqUQ1t5" in str(res.melded_output) - cfg2 = markmeld.load_config_file("tests/test_data/_markmeld_import_relative.yaml") + cfg2 = markmeld.load_config_wrapper("tests/test_data/_markmeld_import_relative.yaml") mm2 = markmeld.MarkdownMelder(cfg2) res = mm2.build_target("imported_target", print_only=True) @@ -156,7 +156,7 @@ def test_import(): def test_null_jinja_template(): - cfg = markmeld.load_config_file( + cfg = markmeld.load_config_wrapper( "tests/test_data/_markmeld_null_jinja_template.yaml" ) mm = markmeld.MarkdownMelder(cfg) @@ -164,15 +164,15 @@ def test_null_jinja_template(): def test_variable_variables(): - cfg = markmeld.load_config_file("demo_book/book_basic/_markmeld.yaml") + cfg = markmeld.load_config_wrapper("demo_book/book_basic/_markmeld.yaml") mm = markmeld.MarkdownMelder(cfg) res = mm.build_target("default", print_only=True) - cfg2 = markmeld.load_config_file("demo_book/book_var1/_markmeld.yaml") + cfg2 = markmeld.load_config_wrapper("demo_book/book_var1/_markmeld.yaml") mm2 = markmeld.MarkdownMelder(cfg2) res2 = mm2.build_target("default", print_only=True) - cfg3 = markmeld.load_config_file("demo_book/book_var2/_markmeld.yaml") + cfg3 = markmeld.load_config_wrapper("demo_book/book_var2/_markmeld.yaml") mm3 = markmeld.MarkdownMelder(cfg3) res3 = mm3.build_target("default", print_only=True) @@ -180,7 +180,7 @@ def test_variable_variables(): # print("/////" + res2.melded_output + "/////") # print("/////" + res3.melded_output + "/////") - cfg4 = markmeld.load_config_file("demo_book/variable_variables/_markmeld.yaml") + cfg4 = markmeld.load_config_wrapper("demo_book/variable_variables/_markmeld.yaml") mm4 = markmeld.MarkdownMelder(cfg4) res4 = mm4.build_target("default", print_only=True) @@ -190,7 +190,7 @@ def test_variable_variables(): def test_meta_target(): - cfg = markmeld.load_config_file("tests/test_data/prebuild_test/_markmeld.yaml") + cfg = markmeld.load_config_wrapper("tests/test_data/prebuild_test/_markmeld.yaml") mm = markmeld.MarkdownMelder(cfg) res = mm.build_target("my_meta_target", print_only=True) test_path = "tests/test_data/prebuild_test/prebuild_test_file"