Skip to content

Commit

Permalink
Merge pull request #49 from csdms/mcflugen/order-methods
Browse files Browse the repository at this point in the history
Order methods in generated code by BMI group type
  • Loading branch information
mcflugen authored Jan 17, 2024
2 parents 18f359a + a33a395 commit b520307
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 16 deletions.
26 changes: 18 additions & 8 deletions src/bmipy/_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,29 @@
from bmipy._version import __version__


def main(args: tuple[str, ...] | None = None) -> int:
def main(argv: tuple[str, ...] | None = None) -> int:
"""Render a template BMI implementation in Python for class NAME."""
parser = argparse.ArgumentParser()
parser.add_argument("--version", action="version", version=f"bmipy {__version__}")
parser.add_argument("name")

parsed_args = parser.parse_args(args)

if parsed_args.name.isidentifier() and not keyword.iskeyword(parsed_args.name):
print(Template(parsed_args.name).render())
parser.add_argument("name", metavar="NAME", help="Name of the generated BMI class")

group = parser.add_mutually_exclusive_group()
group.add_argument(
"--docstring",
action="store_true",
dest="docstring",
default=True,
help="Add docstrings to the generated methods (default: include docstrings)",
)
group.add_argument("--no-docstring", action="store_false", dest="docstring")

args = parser.parse_args(argv)

if args.name.isidentifier() and not keyword.iskeyword(args.name):
print(Template(args.name).render(with_docstring=args.docstring))
else:
print(
f"💥 💔 💥 {parsed_args.name!r} is not a valid class name in Python",
f"💥 💔 💥 {args.name!r} is not a valid class name in Python",
file=sys.stderr,
)
return 1
Expand Down
57 changes: 49 additions & 8 deletions src/bmipy/_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,40 @@

import inspect
import os
import re
import textwrap
from collections import defaultdict
from collections import OrderedDict

from bmipy.bmi import Bmi

GROUPS = (
("initialize", "initialize"),
("update", "(update|update_until)"),
("finalize", "finalize"),
("info", r"(get_component_name|\w+_var_names|\w+_item_count)"),
("var", r"get_var_\w+"),
("time", r"get_\w*time\w*"),
("value", r"(get|set)_value\w*"),
("grid", r"get_grid_\w+"),
)


class Template:
"""Create template BMI implementations."""

def __init__(self, name: str):
self._name = name
self._funcs = dict(inspect.getmembers(Bmi, inspect.isfunction))

def render(self) -> str:
funcs = dict(inspect.getmembers(Bmi, inspect.isfunction))

names = sort_methods(frozenset(funcs))

self._funcs = OrderedDict(
(name, funcs.pop(name)) for name in names
) | OrderedDict(sorted(funcs.items()))

def render(self, with_docstring: bool = True) -> str:
"""Render a module that defines a class implementing a Bmi."""
prefix = f"""\
from __future__ import annotations
Expand All @@ -30,13 +51,15 @@ def render(self) -> str:
class {self._name}(Bmi):
"""
return prefix + (os.linesep * 2).join(
[self._render_func(name) for name in sorted(self._funcs)]
[
self._render_func(name, with_docstring=with_docstring)
for name in self._funcs
]
)

def _render_func(self, name: str) -> str:
def _render_func(self, name: str, with_docstring: bool = True) -> str:
annotations = inspect.get_annotations(self._funcs[name])
signature = inspect.signature(self._funcs[name], eval_str=False)

docstring = textwrap.indent(
'"""' + dedent_docstring(self._funcs[name].__doc__) + '"""', " "
)
Expand All @@ -47,14 +70,32 @@ def _render_func(self, name: str) -> str:
tuple(signature.parameters),
annotations,
width=84,
),
docstring,
f" raise NotImplementedError({name!r})".replace("'", '"'),
)
]
parts.append(docstring) if with_docstring else None
parts.append(f" raise NotImplementedError({name!r})".replace("'", '"'))

return textwrap.indent(os.linesep.join(parts), " ")


def sort_methods(funcs: frozenset[str]) -> list[str]:
"""Sort methods by group type."""
unmatched = set(funcs)
matched = defaultdict(set)

for group, regex in GROUPS:
pattern = re.compile(regex)

matched[group] = {name for name in unmatched if pattern.match(name)}
unmatched -= matched[group]

ordered = []
for group, _ in GROUPS:
ordered.extend(sorted(matched[group]))

return ordered + sorted(unmatched)


def dedent_docstring(text: str | None, tabsize: int = 4) -> str:
"""Dedent a docstring, ignoring indentation of the first line.
Expand Down
14 changes: 14 additions & 0 deletions tests/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,17 @@ def test_cli_with_hints(capsys):
@pytest.mark.parametrize("bad_name", ["True", "0Bmi"])
def test_cli_with_bad_class_name(capsys, bad_name):
assert main([bad_name]) != 0


def test_cli_docstrings(capsys):
assert main(["MyBmiWithDocstrings", "--docstring"]) == 0
output_default = capsys.readouterr().out

assert main(["MyBmiWithDocstrings", "--docstring"]) == 0
output_with_docstrings = capsys.readouterr().out
assert output_with_docstrings == output_default

assert main(["MyBmiWithoutDocstrings", "--no-docstring"]) == 0
output_without_docstrings = capsys.readouterr().out

assert len(output_with_docstrings) > len(output_without_docstrings)

0 comments on commit b520307

Please sign in to comment.