Skip to content

Commit

Permalink
Disallow overriding targets. Fix #22
Browse files Browse the repository at this point in the history
  • Loading branch information
nsheff committed Nov 7, 2023
1 parent 27e9c07 commit 1260146
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 28 deletions.
4 changes: 2 additions & 2 deletions markmeld/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .melder import MarkdownMelder
from .cli import main
from .utilities import load_config_file
from .utilities import load_config_file, load_config_wrapper

__all__ = ["MarkdownMelder", "load_config_file"]
__all__ = ["MarkdownMelder", "load_config_file", "load_config_wrapper"]

if __name__ == "__main__":
try:
Expand Down
4 changes: 2 additions & 2 deletions markmeld/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .exceptions import *
from .melder import MarkdownMelder
from .utilities import load_config_file, get_file_open_cmd
from .utilities import load_config_wrapper, get_file_open_cmd
from ._version import __version__

tpl = """imports: null
Expand Down Expand Up @@ -148,7 +148,7 @@ def main(test_args=None):
_LOGGER.error(msg)
raise ConfigError(msg)

cfg = load_config_file(args.config, None, args.autocomplete)
cfg = load_config_wrapper(args.config, None, args.autocomplete)

if args.autocomplete:
if "targets" not in cfg:
Expand Down
44 changes: 28 additions & 16 deletions markmeld/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

_LOGGER = getLogger(PKG_NAME)


# define some useful functions
def recursive_get(dat, indices):
"""
Expand Down Expand Up @@ -75,35 +74,50 @@ def format_command(tgt):
# 2. the location of where it should be executed (workpath)
# These are not the same thing.

def load_config_file(filepath, workpath=None, autocomplete=True):
def load_config_wrapper(cfg_path, workpath=None, autocomplete=True):
"""
Wrapper function that maintains a list of imported files, to prevent duplicate imports.
"""
imported_list = {}
return load_config_file(cfg_path, workpath, autocomplete, imported_list)


def load_config_file(filepath, workpath=None, autocomplete=True, imported_list={}):
"""
Loads a configuration file.
@param str filepath Path to configuration file to load
@param str workpath The working path that the target's relative paths are relative to
@return dict Loaded yaml data object.
"""

_LOGGER.debug(f"Loading config file: {filepath}")
_LOGGER.debug(f"Imported list: {imported_list}")
if imported_list.get(filepath):
_LOGGER.debug(f"Already imported: {filepath}")
return {}
try:
with open(filepath, "r") as f:
cfg_data = f.read()
return load_config_data(
cfg_data, os.path.abspath(filepath), workpath, autocomplete
cfg_data, os.path.abspath(filepath), workpath, autocomplete, imported_list
)
except FileNotFoundError as e:
_LOGGER.error(f"Couldn't load config file: {filepath} because: {repr(e)}")
return {} # Allow continuing if file not found
except Exception as e:
_LOGGER.error(f"Couldn't load config file: {filepath} because: {repr(e)}")
return {}
raise e # Fail on other errors


def make_abspath(relpath, filepath, root=None):
if root:
return os.path.join(root, relpath)
return os.path.join(os.path.dirname(filepath), relpath)
return os.path.abspath(os.path.join(os.path.dirname(filepath), relpath))


def load_config_data(cfg_data, filepath=None, workpath=None, autocomplete=True):
def load_config_data(cfg_data, filepath=None, workpath=None, autocomplete=True, imported_list={}):
"""
Recursive loader that parses a yaml string, handles imports, and runs target factories.
Recursive loader that parses a yaml string, handles imports, and runs target factories to create targets.
"""
higher_cfg = yaml.load(cfg_data, Loader=yaml.SafeLoader)
higher_cfg["_cfg_file_path"] = filepath
Expand All @@ -115,7 +129,6 @@ def load_config_data(cfg_data, filepath=None, workpath=None, autocomplete=True):
if "targets" in higher_cfg:
for tgt in higher_cfg["targets"]:
higher_cfg["targets"][tgt]["_defpath"] = filepath
_LOGGER.debug(tgt, higher_cfg["targets"][tgt])
if workpath:
higher_cfg["targets"][tgt]["_workpath"] = workpath
else:
Expand All @@ -125,23 +138,20 @@ def load_config_data(cfg_data, filepath=None, workpath=None, autocomplete=True):
if "imports" in higher_cfg and higher_cfg["imports"]:
_LOGGER.debug("Found imports")
for import_file in higher_cfg["imports"]:
import_file_abspath = os.path.relpath(
make_abspath(expandpath(import_file), expandpath(filepath))
)
import_file_abspath = make_abspath(expandpath(import_file), expandpath(filepath))
if not autocomplete:
_LOGGER.info(f"Specified config file to import: {import_file_abspath}")
deep_update(
lower_cfg,
load_config_file(import_file_abspath, expandpath(filepath)),
warn_override=not autocomplete,
)
imported_list[import_file_abspath] = True

if "imports_relative" in higher_cfg and higher_cfg["imports_relative"]:
_LOGGER.debug("Found relative imports")
for import_file in higher_cfg["imports_relative"]:
import_file_abspath = os.path.relpath(
make_abspath(expandpath(import_file), expandpath(filepath))
)
import_file_abspath = make_abspath(expandpath(import_file), expandpath(filepath))
if not autocomplete:
_LOGGER.info(
f"Specified relative config file to import (relative): {import_file}"
Expand All @@ -151,6 +161,7 @@ def load_config_data(cfg_data, filepath=None, workpath=None, autocomplete=True):
load_config_file(expandpath(import_file_abspath)),
warn_override=not autocomplete,
)
imported_list[import_file_abspath] = True

deep_update(lower_cfg, higher_cfg, warn_override=not autocomplete)

Expand Down Expand Up @@ -183,9 +194,10 @@ def warn_overriding_target(old, new):
_LOGGER.error(f"Overriding target: {tgt}")
_LOGGER.error("Originally defined in: ".rjust(27, " ") + f"{old['targets'][tgt]['_defpath']}")
_LOGGER.error("Redefined in: ".rjust(27, " ") + f"{new['targets'][tgt]['_defpath']}")
raise Exception("Same target name is defined in imported file. Overriding targets is not allowed.")


def deep_update(old, new, warn_override=False):
def deep_update(old, new, warn_override=True):
"""
Like built-in dict update, but recursive.
"""
Expand Down
16 changes: 8 additions & 8 deletions tests/test_markmeld.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,14 @@ def test_inherited_data_propogates_to_target():


def test_import():
cfg = markmeld.load_config_file("tests/test_data/_markmeld_import.yaml")
cfg = markmeld.load_config_wrapper("tests/test_data/_markmeld_import.yaml")
mm = markmeld.MarkdownMelder(cfg)

res = mm.build_target("imported_target", print_only=True)
print(res.melded_output)
assert "rVEeqUQ1t5" in str(res.melded_output)

cfg2 = markmeld.load_config_file("tests/test_data/_markmeld_import_relative.yaml")
cfg2 = markmeld.load_config_wrapper("tests/test_data/_markmeld_import_relative.yaml")
mm2 = markmeld.MarkdownMelder(cfg2)

res = mm2.build_target("imported_target", print_only=True)
Expand All @@ -156,31 +156,31 @@ def test_import():


def test_null_jinja_template():
cfg = markmeld.load_config_file(
cfg = markmeld.load_config_wrapper(
"tests/test_data/_markmeld_null_jinja_template.yaml"
)
mm = markmeld.MarkdownMelder(cfg)
res = mm.build_target("target_name", print_only=True)


def test_variable_variables():
cfg = markmeld.load_config_file("demo_book/book_basic/_markmeld.yaml")
cfg = markmeld.load_config_wrapper("demo_book/book_basic/_markmeld.yaml")
mm = markmeld.MarkdownMelder(cfg)
res = mm.build_target("default", print_only=True)

cfg2 = markmeld.load_config_file("demo_book/book_var1/_markmeld.yaml")
cfg2 = markmeld.load_config_wrapper("demo_book/book_var1/_markmeld.yaml")
mm2 = markmeld.MarkdownMelder(cfg2)
res2 = mm2.build_target("default", print_only=True)

cfg3 = markmeld.load_config_file("demo_book/book_var2/_markmeld.yaml")
cfg3 = markmeld.load_config_wrapper("demo_book/book_var2/_markmeld.yaml")
mm3 = markmeld.MarkdownMelder(cfg3)
res3 = mm3.build_target("default", print_only=True)

# print("/////" + res.melded_output + "/////")
# print("/////" + res2.melded_output + "/////")
# print("/////" + res3.melded_output + "/////")

cfg4 = markmeld.load_config_file("demo_book/variable_variables/_markmeld.yaml")
cfg4 = markmeld.load_config_wrapper("demo_book/variable_variables/_markmeld.yaml")
mm4 = markmeld.MarkdownMelder(cfg4)
res4 = mm4.build_target("default", print_only=True)

Expand All @@ -190,7 +190,7 @@ def test_variable_variables():


def test_meta_target():
cfg = markmeld.load_config_file("tests/test_data/prebuild_test/_markmeld.yaml")
cfg = markmeld.load_config_wrapper("tests/test_data/prebuild_test/_markmeld.yaml")
mm = markmeld.MarkdownMelder(cfg)
res = mm.build_target("my_meta_target", print_only=True)
test_path = "tests/test_data/prebuild_test/prebuild_test_file"
Expand Down

0 comments on commit 1260146

Please sign in to comment.