Skip to content

Commit

Permalink
Merge pull request #169 from DanCardin/dc/propagate
Browse files Browse the repository at this point in the history
feat: Implement propagated arguments.
  • Loading branch information
DanCardin authored Nov 13, 2024
2 parents 4785ce2 + 6a17604 commit 0c1b038
Show file tree
Hide file tree
Showing 11 changed files with 576 additions and 193 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## 0.25

### 0.25.0

- feat: Add `Arg.propagate`.

## 0.24

### 0.24.3
Expand Down
94 changes: 94 additions & 0 deletions docs/source/arg.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ can be used and are interpreted to handle different kinds of CLI input.
:noindex:
```

(arg-action)=
## `Arg.action`

Obliquely referenced through other `Arg` options like `count`, every `Arg` has a
Expand Down Expand Up @@ -210,6 +211,7 @@ supplied value at the CLI level.
:noindex:
```

(arg-group)=
## `Arg.group`: Groups (and Mutual Exclusion)

`Arg(group=...)` can be used to customize the way arguments/options are grouped
Expand Down Expand Up @@ -567,3 +569,95 @@ complex types into the CLI structure; but it in essentially opposite uses. `unpa
a complex type from a single CLI argument, whereas a "destructured" argument composes together multiple
CLI arguments into one object without requiring a separate command.
```

## `Arg.propagate`

Argument propagation is a way of making higher-level arguments available during the parsing of child
subcommands. When an argument is marked as `Arg(propagate=True)`, access to that argument will
be made available in all **child** subcommands, while still recording the value itself to the
object on which it was defined.

```python
from __future__ import annotations
from dataclasses import dataclass
from typing import Annotated
import cappa

@dataclass
class Main:
file: Annotated[str, cappa.Arg(long=True, propagate=True)]
subcommand: cappa.Subcommands[One | Two | None] = None

@dataclass
class One:
...

@dataclass
class Two:
subcommand: cappa.Subcommands[Nested | None] = None

@dataclass
class Nested:
...

print(cappa.parse(Main))
```

Given the above example, all of the following would be valid, and produce `Main(file='config.txt', ...)`:

- `main.py --file config.txt`
- `main.py one --file config.txt`
- `main.py two --file config.txt`
- `main.py two nested --file config.txt`


If defined on a top-level command object (like above), that argument will effectively
be available globally within the CLI, again while actually propagating the value
back to the place at which it was defined.

However if the propagated argument is **not** defined at the top-level, it will
not propagate "upwards" to parent commands; only downward to child subcommands.

```{note}
`Arg.propagate` is not currently enabled/allowed for positional arguments (file an issue if this
is a problem for you!) largely because it's not clear that the feature makes any sense except
on (particularly optional) options.

`Arg.propagate` is not implemented in the `argparse` backend.
```

### Propagated Arg Help

By default propagated arguments are added to child subcommand's help as though the argument
was defined like any other argument.

If you want propagated arguments categorically separated from normal arguments, you can
assign them a distinct [group](#arg-group), which will cause them to be displayed separately.

For example:
```python
group = cappa.Group(name="Global", section=1)

@dataclass
class Command:
other1: Annotated[int, cappa.Arg(long=True)] = 1
other2: Annotated[int, cappa.Arg(long=True)] = 1
foo: Annotated[int, cappa.Arg(long=True, propagate=True, group=group)] = 1
bar: Annotated[str, cappa.Arg(long=True, propagate=True, group=group)] = 1
```

Would yield:

```
Options
[--other1 OTHER1] (Default: 1)
[--other2 OTHER2] (Default: 1)

Global
[--foo FOO] (Default: 1)
[--bar BAR] (Default: 1)
```

Note, this is no different from use of `Arg.group` in any other context, except in that
the argument only exists at the declaration point, so any grouping configuration will
also propagate down into the way child commands render those arguments as well.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "cappa"
version = "0.24.3"
version = "0.25.0"
description = "Declarative CLI argument parser."

urls = {repository = "https://github.com/dancardin/cappa"}
Expand Down
21 changes: 18 additions & 3 deletions src/cappa/arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class ArgAction(enum.Enum):
completion = "completion"

@classmethod
def value_actions(cls) -> typing.Set[ArgAction]:
def meta_actions(cls) -> typing.Set[ArgAction]:
return {cls.help, cls.version, cls.completion}

@classmethod
Expand Down Expand Up @@ -106,7 +106,7 @@ def key(self):
return (self.section, self.order, self.name, self.exclusive)


@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class Arg(typing.Generic[T]):
"""Describe a CLI argument.
Expand Down Expand Up @@ -161,8 +161,13 @@ class Arg(typing.Generic[T]):
has_value: Whether the argument has a value that should be saved back to the destination
type. For most `Arg`, this will default to `True`, however `--help` is an example
of an `Arg` for which it is false.
propagate: Specifies that an argument can be matched to all child. Global arguments only
propagate down. When used at the top-level, in effect it creates a "global" argument.
"""

def __hash__(self):
return id(self)

value_name: str | EmptyType = Empty
short: bool | str | list[str] | None = False
long: bool | str | list[str] | None = False
Expand All @@ -182,6 +187,7 @@ class Arg(typing.Generic[T]):
field_name: str | EmptyType = Empty
deprecated: bool | str = False
show_default: bool = True
propagate: bool = False

destructured: Destructured | None = None
has_value: bool | None = None
Expand Down Expand Up @@ -269,6 +275,11 @@ def normalize(
value_name = infer_value_name(self, field_name, num_args)
has_value = infer_has_value(self, action)

if self.propagate and not short and not long:
raise ValueError(
"`Arg.propagate` requires a non-positional named option (`short` or `long`)."
)

return dataclasses.replace(
self,
default=default,
Expand Down Expand Up @@ -306,6 +317,10 @@ def names_str(self, delimiter: str = ", ", *, n=0) -> str:

return typing.cast(str, self.value_name)

@cached_property
def is_option(self) -> bool:
return bool(self.short or self.long)


def verify_type_compatibility(arg: Arg, field_name: str, type_view: TypeView):
"""Verify classes of annotations are compatible with one another.
Expand Down Expand Up @@ -671,7 +686,7 @@ def infer_has_value(arg: Arg, action: ArgActionType):
if arg.has_value is not None:
return arg.has_value

if isinstance(action, ArgAction) and action in ArgAction.value_actions():
if isinstance(action, ArgAction) and action in ArgAction.meta_actions():
return False

return True
Expand Down
3 changes: 3 additions & 0 deletions src/cappa/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ def add_argument(
dest_prefix="",
**extra_kwargs,
):
if arg.propagate:
raise ValueError("The argparse backend does not support the `Arg.propagate`.")

names: list[str] = []
if arg.short:
short = assert_type(arg.short, list)
Expand Down
110 changes: 85 additions & 25 deletions src/cappa/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import typing
from collections.abc import Callable

from type_lens.type_view import TypeView

from cappa.arg import Arg, Group
from cappa.class_inspect import fields as get_fields
from cappa.class_inspect import get_command, get_command_capable_object
Expand All @@ -14,6 +16,7 @@
from cappa.output import Exit, Output, prompt_types
from cappa.subcommand import Subcommand
from cappa.type_view import CallableView, Empty
from cappa.typing import assert_type

T = typing.TypeVar("T")

Expand Down Expand Up @@ -64,6 +67,8 @@ class Command(typing.Generic[T]):

cmd_cls: type[T]
arguments: list[Arg | Subcommand] = dataclasses.field(default_factory=list)
propagated_arguments: list[Arg] = dataclasses.field(default_factory=list)

name: str | None = None
help: str | None = None
description: str | None = None
Expand Down Expand Up @@ -110,7 +115,9 @@ def real_name(self) -> str:
return re.sub(r"(?<!^)(?=[A-Z])", "-", cls_name).lower()

@classmethod
def collect(cls, command: Command[T]) -> Command[T]:
def collect(
cls, command: Command[T], propagated_arguments: list[Arg] | None = None
) -> Command[T]:
kwargs: CommandArgs = {}

help_text = ClassHelpText.collect(command.cmd_cls)
Expand All @@ -124,33 +131,47 @@ def collect(cls, command: Command[T]) -> Command[T]:
fields = get_fields(command.cmd_cls)
function_view = CallableView.from_callable(command.cmd_cls, include_extras=True)

propagated_arguments = propagated_arguments or []

arguments = []
raw_subcommands: list[tuple[Subcommand, TypeView | None, str | None]] = []
if command.arguments:
param_by_name = {p.name: p for p in function_view.parameters}
arguments: list[Arg | Subcommand] = [
a.normalize(
type_view=param_by_name[typing.cast(str, a.field_name)].type_view
if a.field_name in param_by_name
else None,
default_short=command.default_short,
default_long=command.default_long,
)
if isinstance(a, Arg)
else a.normalize()
for a in command.arguments
]
else:
arguments = []
for arg in command.arguments:
arg_help = help_text.args.get(assert_type(arg.field_name, str))
if isinstance(arg, Arg):
type_view = (
param_by_name[typing.cast(str, arg.field_name)].type_view
if arg.field_name in param_by_name
else None
)
arguments.append(
arg.normalize(
type_view=type_view,
default_short=command.default_short,
default_long=command.default_long,
fallback_help=arg_help,
)
)
else:
raw_subcommands.append((arg, None, None))

else:
for field, param_view in zip(fields, function_view.parameters):
arg_help = help_text.args.get(param_view.name)

maybe_subcommand = Subcommand.collect(
maybe_subcommand = Subcommand.detect(
field,
param_view.type_view,
help_formatter=command.help_formatter,
)
if maybe_subcommand:
arguments.append(maybe_subcommand)
raw_subcommands.append(
(
maybe_subcommand,
param_view.type_view,
field.name,
)
)
else:
arg_defs: list[Arg] = Arg.collect(
field,
Expand All @@ -159,13 +180,30 @@ def collect(cls, command: Command[T]) -> Command[T]:
default_short=command.default_short,
default_long=command.default_long,
)

arguments.extend(arg_defs)

propagating_arguments = [
*propagated_arguments,
*(arg for arg in arguments if arg.propagate),
]
subcommands = [
subcommand.normalize(
type_view,
field_name,
help_formatter=command.help_formatter,
propagated_arguments=propagating_arguments,
)
for subcommand, type_view, field_name in raw_subcommands
]

check_group_identity(arguments)
kwargs["arguments"] = arguments
kwargs["arguments"] = [*arguments, *subcommands]

return dataclasses.replace(command, **kwargs)
return dataclasses.replace(
command,
**kwargs,
propagated_arguments=propagated_arguments,
)

@classmethod
def parse_command(
Expand Down Expand Up @@ -245,6 +283,31 @@ def value_arguments(self):

yield arg

@property
def all_arguments(self) -> typing.Iterable[Arg | Subcommand]:
for arg in self.arguments:
yield arg

for arg in self.propagated_arguments:
yield arg

@property
def options(self) -> typing.Iterable[Arg]:
for arg in self.arguments:
if isinstance(arg, Arg) and arg.is_option:
yield arg

@property
def positional_arguments(self) -> typing.Iterable[Arg | Subcommand]:
for arg in self.arguments:
if (
isinstance(arg, Arg)
and not arg.short
and not arg.long
and not arg.destructured
) or isinstance(arg, Subcommand):
yield arg

def add_meta_actions(
self,
help: Arg | None = None,
Expand Down Expand Up @@ -283,13 +346,10 @@ class HasCommand(typing.Generic[H], typing.Protocol):
__cappa__: typing.ClassVar[Command]


def check_group_identity(args: list[Arg | Subcommand]):
def check_group_identity(args: list[Arg]):
group_identity: dict[str, Group] = {}

for arg in args:
if isinstance(arg, Subcommand):
continue

assert isinstance(arg.group, Group)

name = typing.cast(str, arg.group.name)
Expand Down
Loading

0 comments on commit 0c1b038

Please sign in to comment.