diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bff532..eddad99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## 0.25 + +### 0.25.0 + +- feat: Add `Arg.propagate`. + ## 0.24 ### 0.24.3 diff --git a/docs/source/arg.md b/docs/source/arg.md index 7f6258a..6844cdc 100644 --- a/docs/source/arg.md +++ b/docs/source/arg.md @@ -45,6 +45,7 @@ can be used and are interpreted to handle different kinds of CLI input. :noindex: ``` +(arg-action)= ## `Arg.action` Obliquely referenced through other `Arg` options like `count`, every `Arg` has a @@ -210,6 +211,7 @@ supplied value at the CLI level. :noindex: ``` +(arg-group)= ## `Arg.group`: Groups (and Mutual Exclusion) `Arg(group=...)` can be used to customize the way arguments/options are grouped @@ -567,3 +569,95 @@ complex types into the CLI structure; but it in essentially opposite uses. `unpa a complex type from a single CLI argument, whereas a "destructured" argument composes together multiple CLI arguments into one object without requiring a separate command. ``` + +## `Arg.propagate` + +Argument propagation is a way of making higher-level arguments available during the parsing of child +subcommands. When an argument is marked as `Arg(propagate=True)`, access to that argument will +be made available in all **child** subcommands, while still recording the value itself to the +object on which it was defined. + +```python +from __future__ import annotations +from dataclasses import dataclass +from typing import Annotated +import cappa + +@dataclass +class Main: + file: Annotated[str, cappa.Arg(long=True, propagate=True)] + subcommand: cappa.Subcommands[One | Two | None] = None + +@dataclass +class One: + ... + +@dataclass +class Two: + subcommand: cappa.Subcommands[Nested | None] = None + +@dataclass +class Nested: + ... + +print(cappa.parse(Main)) +``` + +Given the above example, all of the following would be valid, and produce `Main(file='config.txt', ...)`: + +- `main.py --file config.txt` +- `main.py one --file config.txt` +- `main.py two --file config.txt` +- `main.py two nested --file config.txt` + + +If defined on a top-level command object (like above), that argument will effectively +be available globally within the CLI, again while actually propagating the value +back to the place at which it was defined. + +However if the propagated argument is **not** defined at the top-level, it will +not propagate "upwards" to parent commands; only downward to child subcommands. + +```{note} +`Arg.propagate` is not currently enabled/allowed for positional arguments (file an issue if this +is a problem for you!) largely because it's not clear that the feature makes any sense except +on (particularly optional) options. + +`Arg.propagate` is not implemented in the `argparse` backend. +``` + +### Propagated Arg Help + +By default propagated arguments are added to child subcommand's help as though the argument +was defined like any other argument. + +If you want propagated arguments categorically separated from normal arguments, you can +assign them a distinct [group](#arg-group), which will cause them to be displayed separately. + +For example: +```python +group = cappa.Group(name="Global", section=1) + +@dataclass +class Command: + other1: Annotated[int, cappa.Arg(long=True)] = 1 + other2: Annotated[int, cappa.Arg(long=True)] = 1 + foo: Annotated[int, cappa.Arg(long=True, propagate=True, group=group)] = 1 + bar: Annotated[str, cappa.Arg(long=True, propagate=True, group=group)] = 1 +``` + +Would yield: + +``` + Options + [--other1 OTHER1] (Default: 1) + [--other2 OTHER2] (Default: 1) + + Global + [--foo FOO] (Default: 1) + [--bar BAR] (Default: 1) +``` + +Note, this is no different from use of `Arg.group` in any other context, except in that +the argument only exists at the declaration point, so any grouping configuration will +also propagate down into the way child commands render those arguments as well. diff --git a/pyproject.toml b/pyproject.toml index 64b5998..14e7289 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cappa" -version = "0.24.3" +version = "0.25.0" description = "Declarative CLI argument parser." urls = {repository = "https://github.com/dancardin/cappa"} diff --git a/src/cappa/arg.py b/src/cappa/arg.py index b130112..2b8f89d 100644 --- a/src/cappa/arg.py +++ b/src/cappa/arg.py @@ -53,7 +53,7 @@ class ArgAction(enum.Enum): completion = "completion" @classmethod - def value_actions(cls) -> typing.Set[ArgAction]: + def meta_actions(cls) -> typing.Set[ArgAction]: return {cls.help, cls.version, cls.completion} @classmethod @@ -106,7 +106,7 @@ def key(self): return (self.section, self.order, self.name, self.exclusive) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Arg(typing.Generic[T]): """Describe a CLI argument. @@ -161,8 +161,13 @@ class Arg(typing.Generic[T]): has_value: Whether the argument has a value that should be saved back to the destination type. For most `Arg`, this will default to `True`, however `--help` is an example of an `Arg` for which it is false. + propagate: Specifies that an argument can be matched to all child. Global arguments only + propagate down. When used at the top-level, in effect it creates a "global" argument. """ + def __hash__(self): + return id(self) + value_name: str | EmptyType = Empty short: bool | str | list[str] | None = False long: bool | str | list[str] | None = False @@ -182,6 +187,7 @@ class Arg(typing.Generic[T]): field_name: str | EmptyType = Empty deprecated: bool | str = False show_default: bool = True + propagate: bool = False destructured: Destructured | None = None has_value: bool | None = None @@ -269,6 +275,11 @@ def normalize( value_name = infer_value_name(self, field_name, num_args) has_value = infer_has_value(self, action) + if self.propagate and not short and not long: + raise ValueError( + "`Arg.propagate` requires a non-positional named option (`short` or `long`)." + ) + return dataclasses.replace( self, default=default, @@ -306,6 +317,10 @@ def names_str(self, delimiter: str = ", ", *, n=0) -> str: return typing.cast(str, self.value_name) + @cached_property + def is_option(self) -> bool: + return bool(self.short or self.long) + def verify_type_compatibility(arg: Arg, field_name: str, type_view: TypeView): """Verify classes of annotations are compatible with one another. @@ -671,7 +686,7 @@ def infer_has_value(arg: Arg, action: ArgActionType): if arg.has_value is not None: return arg.has_value - if isinstance(action, ArgAction) and action in ArgAction.value_actions(): + if isinstance(action, ArgAction) and action in ArgAction.meta_actions(): return False return True diff --git a/src/cappa/argparse.py b/src/cappa/argparse.py index 4dd7760..51212d9 100644 --- a/src/cappa/argparse.py +++ b/src/cappa/argparse.py @@ -216,6 +216,9 @@ def add_argument( dest_prefix="", **extra_kwargs, ): + if arg.propagate: + raise ValueError("The argparse backend does not support the `Arg.propagate`.") + names: list[str] = [] if arg.short: short = assert_type(arg.short, list) diff --git a/src/cappa/command.py b/src/cappa/command.py index bad2a33..48a11b1 100644 --- a/src/cappa/command.py +++ b/src/cappa/command.py @@ -5,6 +5,8 @@ import typing from collections.abc import Callable +from type_lens.type_view import TypeView + from cappa.arg import Arg, Group from cappa.class_inspect import fields as get_fields from cappa.class_inspect import get_command, get_command_capable_object @@ -14,6 +16,7 @@ from cappa.output import Exit, Output, prompt_types from cappa.subcommand import Subcommand from cappa.type_view import CallableView, Empty +from cappa.typing import assert_type T = typing.TypeVar("T") @@ -64,6 +67,8 @@ class Command(typing.Generic[T]): cmd_cls: type[T] arguments: list[Arg | Subcommand] = dataclasses.field(default_factory=list) + propagated_arguments: list[Arg] = dataclasses.field(default_factory=list) + name: str | None = None help: str | None = None description: str | None = None @@ -110,7 +115,9 @@ def real_name(self) -> str: return re.sub(r"(? Command[T]: + def collect( + cls, command: Command[T], propagated_arguments: list[Arg] | None = None + ) -> Command[T]: kwargs: CommandArgs = {} help_text = ClassHelpText.collect(command.cmd_cls) @@ -124,33 +131,47 @@ def collect(cls, command: Command[T]) -> Command[T]: fields = get_fields(command.cmd_cls) function_view = CallableView.from_callable(command.cmd_cls, include_extras=True) + propagated_arguments = propagated_arguments or [] + + arguments = [] + raw_subcommands: list[tuple[Subcommand, TypeView | None, str | None]] = [] if command.arguments: param_by_name = {p.name: p for p in function_view.parameters} - arguments: list[Arg | Subcommand] = [ - a.normalize( - type_view=param_by_name[typing.cast(str, a.field_name)].type_view - if a.field_name in param_by_name - else None, - default_short=command.default_short, - default_long=command.default_long, - ) - if isinstance(a, Arg) - else a.normalize() - for a in command.arguments - ] - else: - arguments = [] + for arg in command.arguments: + arg_help = help_text.args.get(assert_type(arg.field_name, str)) + if isinstance(arg, Arg): + type_view = ( + param_by_name[typing.cast(str, arg.field_name)].type_view + if arg.field_name in param_by_name + else None + ) + arguments.append( + arg.normalize( + type_view=type_view, + default_short=command.default_short, + default_long=command.default_long, + fallback_help=arg_help, + ) + ) + else: + raw_subcommands.append((arg, None, None)) + else: for field, param_view in zip(fields, function_view.parameters): arg_help = help_text.args.get(param_view.name) - maybe_subcommand = Subcommand.collect( + maybe_subcommand = Subcommand.detect( field, param_view.type_view, - help_formatter=command.help_formatter, ) if maybe_subcommand: - arguments.append(maybe_subcommand) + raw_subcommands.append( + ( + maybe_subcommand, + param_view.type_view, + field.name, + ) + ) else: arg_defs: list[Arg] = Arg.collect( field, @@ -159,13 +180,30 @@ def collect(cls, command: Command[T]) -> Command[T]: default_short=command.default_short, default_long=command.default_long, ) - arguments.extend(arg_defs) + propagating_arguments = [ + *propagated_arguments, + *(arg for arg in arguments if arg.propagate), + ] + subcommands = [ + subcommand.normalize( + type_view, + field_name, + help_formatter=command.help_formatter, + propagated_arguments=propagating_arguments, + ) + for subcommand, type_view, field_name in raw_subcommands + ] + check_group_identity(arguments) - kwargs["arguments"] = arguments + kwargs["arguments"] = [*arguments, *subcommands] - return dataclasses.replace(command, **kwargs) + return dataclasses.replace( + command, + **kwargs, + propagated_arguments=propagated_arguments, + ) @classmethod def parse_command( @@ -245,6 +283,31 @@ def value_arguments(self): yield arg + @property + def all_arguments(self) -> typing.Iterable[Arg | Subcommand]: + for arg in self.arguments: + yield arg + + for arg in self.propagated_arguments: + yield arg + + @property + def options(self) -> typing.Iterable[Arg]: + for arg in self.arguments: + if isinstance(arg, Arg) and arg.is_option: + yield arg + + @property + def positional_arguments(self) -> typing.Iterable[Arg | Subcommand]: + for arg in self.arguments: + if ( + isinstance(arg, Arg) + and not arg.short + and not arg.long + and not arg.destructured + ) or isinstance(arg, Subcommand): + yield arg + def add_meta_actions( self, help: Arg | None = None, @@ -283,13 +346,10 @@ class HasCommand(typing.Generic[H], typing.Protocol): __cappa__: typing.ClassVar[Command] -def check_group_identity(args: list[Arg | Subcommand]): +def check_group_identity(args: list[Arg]): group_identity: dict[str, Group] = {} for arg in args: - if isinstance(arg, Subcommand): - continue - assert isinstance(arg.group, Group) name = typing.cast(str, arg.group.name) diff --git a/src/cappa/destructure.py b/src/cappa/destructure.py index 44d0687..3306b7a 100644 --- a/src/cappa/destructure.py +++ b/src/cappa/destructure.py @@ -1,13 +1,13 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, replace from type_lens import TypeView from cappa.arg import Arg, ArgActionType from cappa.invoke import fulfill_deps from cappa.output import Output -from cappa.parser import ParseContext, Value, determine_action_handler +from cappa.parser import ParseContext, ParseState, Value, determine_action_handler from cappa.typing import assert_type @@ -24,7 +24,7 @@ def destructure(arg: Arg, type_view: TypeView): command: Command = Command.get(type_view.annotation) virtual_args = Command.collect(command).arguments - arg.parse = lambda v: command.cmd_cls(**v) + arg = replace(arg, parse=lambda v: command.cmd_cls(**v)) result = [arg] for virtual_arg in virtual_args: @@ -34,8 +34,11 @@ def destructure(arg: Arg, type_view: TypeView): ) assert virtual_arg.action - virtual_arg.action = restructure(arg, virtual_arg.action) - virtual_arg.has_value = False + virtual_arg = replace( + virtual_arg, + action=restructure(arg, virtual_arg.action), + has_value=False, + ) result.append(virtual_arg) @@ -45,13 +48,15 @@ def destructure(arg: Arg, type_view: TypeView): def restructure(root_arg: Arg, action: ArgActionType): action_handler = determine_action_handler(action) - def restructure_action(context: ParseContext, arg: Arg, value: Value): + def restructure_action( + parse_state: ParseState, context: ParseContext, arg: Arg, value: Value + ): root_field_name = assert_type(root_arg.field_name, str) result = context.result.setdefault(root_field_name, {}) fulfilled_deps: dict = { - Command: context.command, - Output: context.output, + Command: parse_state.current_command, + Output: parse_state.output, ParseContext: context, Arg: arg, Value: value, diff --git a/src/cappa/help.py b/src/cappa/help.py index df60aab..8f08c75 100644 --- a/src/cappa/help.py +++ b/src/cappa/help.py @@ -53,7 +53,7 @@ def create_version_arg(version: str | Arg | None = None) -> Arg | None: ) if version.long is True: - version.long = "--version" + version = replace(version, long="--version") return version.normalize( action=ArgAction.version, field_name="version", default=None @@ -211,7 +211,7 @@ def by_group(arg: Arg | Subcommand): group = assert_type(arg.group, Group) return (group.name, group.exclusive) - sorted_args = sorted(command.arguments, key=by_group_key) + sorted_args = sorted(command.all_arguments, key=by_group_key) return [ (g, [a for a in args if include_hidden or not a.hidden]) for g, args in groupby(sorted_args, key=by_group) @@ -229,16 +229,15 @@ def add_short_args(prog: str, arg_groups: list[ArgGroup]) -> str: def format_arg_name(arg: Arg | Subcommand, delimiter, *, n=0) -> str: if isinstance(arg, Arg): - is_option = arg.short or arg.long has_value = not ArgAction.is_non_value_consuming(arg.action) arg_names = arg.names_str(delimiter, n=n) - if not is_option: + if not arg.is_option: arg_names = arg_names.upper() text = f"[cappa.arg]{arg_names}[/cappa.arg]" - if is_option and has_value: + if arg.is_option and has_value: name = typing.cast(str, arg.value_name).upper() text = f"{text} [cappa.arg.name]{name}[/cappa.arg.name]" diff --git a/src/cappa/parser.py b/src/cappa/parser.py index 25b5f13..8bfc7c7 100644 --- a/src/cappa/parser.py +++ b/src/cappa/parser.py @@ -3,6 +3,7 @@ import dataclasses import typing from collections import deque +from functools import cached_property from cappa.arg import Arg, ArgAction, ArgActionType, Group from cappa.command import Command, Subcommand @@ -34,9 +35,8 @@ class HelpAction(RuntimeError): command_name: str @classmethod - def from_context(cls, context: ParseContext, command: Command): - name = " ".join(c.real_name() for c in context.command_stack) - raise cls(command, name) + def from_parse_state(cls, parse_state: ParseState, command: Command): + raise cls(command, parse_state.prog) @dataclasses.dataclass @@ -71,34 +71,34 @@ def backend( prog: str, provide_completions: bool = False, ) -> tuple[typing.Any, Command[T], dict[str, typing.Any]]: - args = RawArg.collect(argv, provide_completions=provide_completions) - - context = ParseContext.from_command(args, [command], output) - context.provide_completions = provide_completions + parse_state = ParseState.from_command( + argv, command, output=output, provide_completions=provide_completions + ) + context = ParseContext.from_command(parse_state.current_command) try: try: - parse(context) + parse(parse_state, context) except HelpAction as e: raise HelpExit( e.command.help_formatter(e.command, e.command_name), code=0, - prog=context.prog, + prog=parse_state.prog, ) except VersionAction as e: raise Exit( - typing.cast(str, e.version.value_name), code=0, prog=context.prog + typing.cast(str, e.version.value_name), code=0, prog=parse_state.prog ) except BadArgumentError as e: - if context.provide_completions and e.arg: + if parse_state.provide_completions and e.arg: completions = e.arg.completion(e.value) if e.arg.completion else [] raise CompletionAction(*completions) - raise Exit(str(e), code=2, prog=context.prog, command=e.command) + raise Exit(str(e), code=2, prog=parse_state.prog, command=e.command) except CompletionAction as e: from cappa.completion.base import execute, format_completions - if context.provide_completions: + if provide_completions: completions = format_completions(*e.completions) raise Exit(completions, code=0) @@ -107,59 +107,100 @@ def backend( if provide_completions: raise Exit(code=0) - return (context, context.command_stack[-1] or command, context.result) + return (parse_state, parse_state.current_command or command, context.result) @dataclasses.dataclass -class ParseContext: - remaining_args: deque[RawArg | RawOption] - options: dict[str, Arg] - arguments: deque[Arg | Subcommand] - missing_options: set[str] +class ParseState: + """The overall state of the argument parse.""" + remaining_args: deque[RawArg | RawOption] + command_stack: list[Command] output: Output + provide_completions: bool = False - consumed_args: list[RawArg | RawOption] = dataclasses.field(default_factory=list) + @classmethod + def from_command( + cls, + argv: list[str], + command: Command, + output: Output, + provide_completions: bool = False, + ): + args = RawArg.collect(argv, provide_completions=provide_completions) + return cls( + args, + command_stack=[command], + output=output, + provide_completions=provide_completions, + ) + + @property + def current_command(self): + return self.command_stack[-1] + + @property + def prog(self): + return " ".join(c.real_name() for c in self.command_stack) + + def push_command(self, command: Command): + self.command_stack.append(command) + + def push_arg(self, arg: RawArg): + self.remaining_args.appendleft(arg) + + def has_values(self) -> bool: + return bool(self.remaining_args) + + def peek_value(self): + if not self.remaining_args: + return None + return self.remaining_args[0] + + def next_value(self): + return self.remaining_args.popleft() + + +@dataclasses.dataclass +class ParseContext: + """The parsing context specific to a command.""" + + command: Command + arguments: deque[Arg | Subcommand] + missing_options: set[str] + options: dict[str, Arg] + propagated_options: set[str] + parent_context: ParseContext | None = None exclusive_args: dict[str, Arg] = dataclasses.field(default_factory=dict) result: dict[str, typing.Any] = dataclasses.field(default_factory=dict) - command_stack: list[Command] = dataclasses.field(default_factory=list) - - provide_completions: bool = False @classmethod def from_command( cls, - args: deque[RawArg | RawOption], - command_stack: list[Command], - output: Output, + command: Command, + parent_context: ParseContext | None = None, ) -> ParseContext: - command = command_stack[-1] - options, missing_options = cls.collect_options(command) - arguments = deque(cls.collect_arguments(command)) + options, missing_options, propagated_options = cls.collect_options(command) + arguments = deque(command.positional_arguments) return cls( - args, - options, - arguments, - output=output, + command=command, + parent_context=parent_context, + options=options, + propagated_options=propagated_options, + arguments=arguments, missing_options=missing_options, - command_stack=command_stack, ) @staticmethod - def collect_options(command: Command) -> tuple[dict[str, Arg], set[str]]: + def collect_options( + command: Command, + ) -> tuple[dict[str, Arg], set[str], set[str]]: result = {} unique_names = set() - for arg in command.arguments: - field_name = typing.cast(str, arg.field_name) - if not isinstance(arg, Arg): - continue - - if arg.short or arg.long: - if arg.action not in ArgAction.value_actions(): - unique_names.add(field_name) - result[field_name] = arg + propagated_options = set() + def add_option_names(arg: Arg): for opts in (arg.short, arg.long): if not opts: continue @@ -170,45 +211,65 @@ def collect_options(command: Command) -> tuple[dict[str, Arg], set[str]]: result[key] = arg - return result, unique_names + for arg in command.options: + field_name = typing.cast(str, arg.field_name) - @staticmethod - def collect_arguments(command: Command) -> list[Arg | Subcommand]: - result = [] - for arg in command.arguments: - if ( - isinstance(arg, Arg) - and not arg.short - and not arg.long - and not arg.destructured - ) or isinstance(arg, Subcommand): - result.append(arg) - return result + if arg.action not in ArgAction.meta_actions(): + unique_names.add(field_name) + result[field_name] = arg + add_option_names(arg) - @property - def prog(self): - return " ".join(c.real_name() for c in self.command_stack) + for arg in command.propagated_arguments: + field_name = typing.cast(str, arg.field_name) - @property - def command(self): - return self.command_stack[-1] + if field_name in result: + continue - def has_values(self) -> bool: - return bool(self.remaining_args) + propagated_options.add(field_name) + result[field_name] = arg + add_option_names(arg) - def peek_value(self): - if not self.remaining_args: - return None - return self.remaining_args[0] + return result, unique_names, propagated_options - def next_value(self): - arg = self.remaining_args.popleft() - self.consumed_args.append(arg) - return arg + @cached_property + def propagated_context(self) -> dict[str, ParseContext]: + parent_context = ( + self.parent_context.propagated_context if self.parent_context else {} + ) + self_options = { + assert_type(o.field_name, str): o + for o in self.command.options + if o.propagate + } + self_context = dict.fromkeys(self_options, self) + return {**parent_context, **self_context} def next_argument(self): return self.arguments.popleft() + def set_result( + self, + field_name: str, + value: typing.Any, + option: RawOption | None = None, + has_value: bool = True, + ): + context = self + if option: + if field_name in self.propagated_options: + context = self.propagated_context[field_name] + + if field_name in context.missing_options: + context.missing_options.remove(field_name) + + if has_value: + context.result[field_name] = value + + def push(self, command: Command, name: str) -> ParseContext: + nested_context = ParseContext.from_command(command, parent_context=self) + nested_context.result["__name__"] = name + return nested_context + @dataclasses.dataclass class RawArg: @@ -276,19 +337,19 @@ def from_str(cls, arg: str) -> RawOption: return cls(name=name, is_long=is_long, value=value) -def parse(context: ParseContext) -> None: +def parse(parse_state: ParseState, context: ParseContext) -> None: while True: - while isinstance(context.peek_value(), RawOption): - arg = typing.cast(RawOption, context.next_value()) + while isinstance(parse_state.peek_value(), RawOption): + arg = typing.cast(RawOption, parse_state.next_value()) if arg.is_long: - parse_option(context, arg) + parse_option(parse_state, context, arg) else: - parse_short_option(context, arg) + parse_short_option(parse_state, context, arg) - parse_args(context) + parse_args(parse_state, context) - if not context.has_values(): + if not parse_state.has_values(): break # Options are not explicitly iterated over because they can occur multiple times non-contiguouesly. @@ -303,18 +364,20 @@ def parse(context: ParseContext) -> None: raise BadArgumentError( f"The following arguments are required: {names}", value="", - command=context.command, + command=parse_state.current_command, arg=required_missing_options[0], ) -def parse_option(context: ParseContext, raw: RawOption) -> None: +def parse_option( + parse_state: ParseState, context: ParseContext, raw: RawOption +) -> None: if raw.name not in context.options: possible_values = [ name for name in context.options if name.startswith(raw.name) ] - if context.provide_completions: + if parse_state.provide_completions: options = [ Completion(option, help=context.options[option].help) for option in possible_values @@ -325,27 +388,31 @@ def parse_option(context: ParseContext, raw: RawOption) -> None: if possible_values: message += f" (Did you mean: {', '.join(possible_values)})" - raise BadArgumentError(message, value=raw.name, command=context.command) + raise BadArgumentError( + message, value=raw.name, command=parse_state.current_command + ) arg = context.options[raw.name] - consume_arg(context, arg, raw) + consume_arg(parse_state, context, arg, raw) -def parse_short_option(context: ParseContext, arg: RawOption) -> None: - if arg.name == "-" and context.provide_completions: - return parse_option(context, arg) +def parse_short_option( + parse_state: ParseState, context: ParseContext, arg: RawOption +) -> None: + if arg.name == "-" and parse_state.provide_completions: + return parse_option(parse_state, context, arg) virtual_options, virtual_arg = generate_virtual_args(arg, context.options) *first_virtual_options, last_virtual_option = virtual_options for opt in first_virtual_options: - parse_option(context, opt) + parse_option(parse_state, context, opt) if virtual_arg: - context.remaining_args.appendleft(virtual_arg) + parse_state.push_arg(virtual_arg) - parse_option(context, last_virtual_option) + parse_option(parse_state, context, last_virtual_option) return None @@ -391,25 +458,25 @@ def generate_virtual_args( return (result, raw_arg) -def parse_args(context: ParseContext) -> None: +def parse_args(parse_state: ParseState, context: ParseContext) -> None: while context.arguments: - if isinstance(context.peek_value(), RawOption): + if isinstance(parse_state.peek_value(), RawOption): break arg = context.next_argument() if isinstance(arg, Subcommand): - consume_subcommand(context, arg) + consume_subcommand(parse_state, context, arg) else: - consume_arg(context, arg) + consume_arg(parse_state, context, arg) else: - value = context.peek_value() + value = parse_state.peek_value() if value is None or isinstance(value, RawOption): return raw_values = [] - while context.peek_value(): - next_val = context.next_value() + while parse_state.peek_value(): + next_val = parse_state.next_value() if not isinstance(next_val, RawArg): break raw_values.append(next_val.raw) @@ -417,13 +484,15 @@ def parse_args(context: ParseContext) -> None: raise BadArgumentError( f"Unrecognized arguments: {', '.join(raw_values)}", value=raw_values, - command=context.command, + command=parse_state.current_command, ) -def consume_subcommand(context: ParseContext, arg: Subcommand) -> typing.Any: +def consume_subcommand( + parse_state: ParseState, context: ParseContext, arg: Subcommand +) -> typing.Any: try: - value = context.next_value() + value = parse_state.next_value() except IndexError: if not arg.required: return @@ -431,7 +500,7 @@ def consume_subcommand(context: ParseContext, arg: Subcommand) -> typing.Any: raise BadArgumentError( f"A command is required: {{{format_subcommand_names(arg.names())}}}", value="", - command=context.command, + command=parse_state.current_command, arg=arg, ) @@ -445,31 +514,27 @@ def consume_subcommand(context: ParseContext, arg: Subcommand) -> typing.Any: raise BadArgumentError( message, value=value.raw, - command=context.command, + command=parse_state.current_command, arg=arg, ) command = arg.options[value.raw] - check_deprecated(context, command) - - context.command_stack.append(command) + check_deprecated(parse_state, command) - nested_context = ParseContext.from_command( - context.remaining_args, - command_stack=context.command_stack, - output=context.output, - ) - nested_context.provide_completions = context.provide_completions - nested_context.result["__name__"] = value.raw + parse_state.push_command(command) + nested_context = context.push(command, value.raw) - parse(nested_context) + parse(parse_state, nested_context) name = typing.cast(str, arg.field_name) context.result[name] = nested_context.result def consume_arg( - context: ParseContext, arg: Arg, option: RawOption | None = None + parse_state: ParseState, + context: ParseContext, + arg: Arg, + option: RawOption | None = None, ) -> typing.Any: field_name = typing.cast(str, arg.field_name) @@ -493,11 +558,11 @@ def consume_arg( if requires_values: result = [] while num_args: - if isinstance(context.peek_value(), RawOption): + if isinstance(parse_state.peek_value(), RawOption): break try: - next_val = typing.cast(RawArg, context.next_value()) + next_val = typing.cast(RawArg, parse_state.next_value()) except IndexError: break @@ -518,11 +583,11 @@ def consume_arg( raise BadArgumentError( f"Invalid choice: '{result}' (choose from {choices})", value=result, - command=context.command, + command=parse_state.current_command, arg=arg, ) - if context.provide_completions and not context.has_values(): + if parse_state.provide_completions and not parse_state.has_values(): if arg.completion: completions: list[Completion] | list[FileCompletion] = ( arg.completion(result) @@ -537,7 +602,7 @@ def consume_arg( raise BadArgumentError( f"Option '{arg.value_name}' requires an argument", value="", - command=context.command, + command=parse_state.current_command, arg=arg, ) else: @@ -551,7 +616,7 @@ def consume_arg( raise BadArgumentError( message, value=result, - command=context.command, + command=parse_state.current_command, arg=arg, ) @@ -565,21 +630,19 @@ def consume_arg( f"Argument '{arg.names_str('/')}' is not allowed with argument" f" '{exclusive_arg.names_str('/')}'", value=result, - command=context.command, + command=parse_state.current_command, arg=arg, ) context.exclusive_args[group_name] = arg - if option and field_name in context.missing_options: - context.missing_options.remove(field_name) - action_handler = determine_action_handler(arg.action) fulfilled_deps: dict = { - Command: context.command, - Output: context.output, + Command: parse_state.current_command, + Output: parse_state.output, ParseContext: context, + ParseState: parse_state, Arg: arg, Value: Value(result), } @@ -588,14 +651,14 @@ def consume_arg( kwargs = fulfill_deps(action_handler, fulfilled_deps).kwargs result = action_handler(**kwargs) - if arg.has_value: - context.result[field_name] = result - check_deprecated(context, arg, option) + context.set_result(field_name, result, option, assert_type(arg.has_value, bool)) + + check_deprecated(parse_state, arg, option) def check_deprecated( - context: ParseContext, arg: Arg | Command, option: RawOption | None = None + parse_state: ParseState, arg: Arg | Command, option: RawOption | None = None ) -> None: if not arg.deprecated: return @@ -615,7 +678,7 @@ def check_deprecated( if isinstance(arg.deprecated, str): message += f": {arg.deprecated}" - context.output.error(message) + parse_state.output.error(message) @dataclasses.dataclass @@ -655,7 +718,7 @@ def determine_action_handler(action: ArgActionType | None): process_options: dict[ArgAction, typing.Callable] = { - ArgAction.help: HelpAction.from_context, + ArgAction.help: HelpAction.from_parse_state, ArgAction.version: VersionAction.from_arg, ArgAction.completion: CompletionAction.from_value, ArgAction.set: store_set, diff --git a/src/cappa/subcommand.py b/src/cappa/subcommand.py index 9933d60..1ef1d05 100644 --- a/src/cappa/subcommand.py +++ b/src/cappa/subcommand.py @@ -12,6 +12,7 @@ from cappa.typing import T, assert_type, find_annotations if typing.TYPE_CHECKING: + from cappa.arg import Arg from cappa.command import Command from cappa.help import HelpFormatable @@ -48,33 +49,25 @@ class Subcommand: types: typing.Iterable[type] | EmptyType = Empty @classmethod - def collect( - cls, - field: Field, - type_view: TypeView, - help_formatter: HelpFormatable | None = None, - ) -> Subcommand | None: - subcommand = find_annotations(type_view, Subcommand) or None + def detect(cls, field: Field, type_view: TypeView) -> Subcommand | None: + subcommands = find_annotations(type_view, Subcommand) or None field_metadata = extract_dataclass_metadata(field, Subcommand) if field_metadata: - subcommand = field_metadata + subcommands = field_metadata - if not subcommand: + if not subcommands: return None - assert len(subcommand) == 1 - return subcommand[0].normalize( - type_view, - field_name=field.name, - help_formatter=help_formatter, - ) + assert len(subcommands) == 1 + return subcommands[0] def normalize( self, type_view: TypeView | None = None, field_name: str | None = None, help_formatter: HelpFormatable | None = None, + propagated_arguments: list[Arg] | None = None, ) -> Self: if type_view is None: type_view = TypeView(...) @@ -82,7 +75,12 @@ def normalize( field_name = field_name or assert_type(self.field_name, str) types = infer_types(self, type_view) required = infer_required(self, type_view) - options = infer_options(self, types, help_formatter=help_formatter) + options = infer_options( + self, + types, + help_formatter=help_formatter, + propagated_arguments=propagated_arguments, + ) group = infer_group(self) return dataclasses.replace( @@ -133,12 +131,15 @@ def infer_options( arg: Subcommand, types: typing.Iterable[type], help_formatter: HelpFormatable | None = None, + propagated_arguments: list[Arg] | None = None, ) -> dict[str, Command]: from cappa.command import Command if arg.options: return { - name: Command.collect(type_command) + name: Command.collect( + type_command, propagated_arguments=propagated_arguments + ) for name, type_command in arg.options.items() } @@ -146,7 +147,9 @@ def infer_options( for type_ in types: type_command: Command = Command.get(type_, help_formatter=help_formatter) type_name = type_command.real_name() - options[type_name] = Command.collect(type_command) + options[type_name] = Command.collect( + type_command, propagated_arguments=propagated_arguments + ) return options diff --git a/tests/arg/test_propagate.py b/tests/arg/test_propagate.py new file mode 100644 index 0000000..8312031 --- /dev/null +++ b/tests/arg/test_propagate.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import pytest +from typing_extensions import Annotated + +import cappa +from tests.utils import CapsysOutput, parse + + +@dataclass +class Command: + foo: Annotated[int, cappa.Arg(long=True, propagate=True, help="Everywhere")] = 1 + sub: cappa.Subcommands[One | Two | None] = None + + +@dataclass +class One: ... + + +@dataclass +class Two: ... + + +def test_propagated_arg(): + result = parse(Command) + assert result == Command(1, None) + + result = parse(Command, "--foo=4") + assert result == Command(4, None) + + result = parse(Command, "one", "--foo=5") + assert result == Command(5, One()) + + result = parse(Command, "two", "--foo=6") + assert result == Command(6, Two()) + + +def test_propagate_incompatible_with_argparse(): + with pytest.raises(ValueError): + parse(Command, backend=cappa.argparse.backend) + + +@dataclass +class Required: + foo: Annotated[int, cappa.Arg(long=True, propagate=True)] + sub: cappa.Subcommands[One | Two | None] = None + + +def test_required_propagated_arg(): + with pytest.raises(cappa.Exit): + parse(Required) + + result = parse(Required, "--foo=4") + assert result == Required(4, None) + + result = parse(Required, "one", "--foo=5") + assert result == Required(5, One()) + + result = parse(Required, "two", "--foo=6") + assert result == Required(6, Two()) + + +@dataclass +class ChildOverride: + foo: Annotated[int, cappa.Arg(long=True, propagate=True)] = 1 + sub: cappa.Subcommands[Child | None] = None + + +@dataclass +class Child: + foo: Annotated[int, cappa.Arg(long=True)] = 2 + + +def test_child_override(): + result = parse(ChildOverride) + assert result == ChildOverride(1, None) + + result = parse(ChildOverride, "--foo=4") + assert result == ChildOverride(4, None) + + result = parse(ChildOverride, "--foo=4", "child") + assert result == ChildOverride(4, Child(2)) + + result = parse(ChildOverride, "child") + assert result == ChildOverride(1, Child(2)) + + result = parse(ChildOverride, "child", "--foo=5") + assert result == ChildOverride(1, Child(5)) + + result = parse(ChildOverride, "--foo=3", "child", "--foo=6") + assert result == ChildOverride(3, Child(6)) + + +def test_propagate_requires_option(): + @dataclass + class ChildOverride: + foo: Annotated[int, cappa.Arg(propagate=True)] + + with pytest.raises(ValueError) as e: + parse(ChildOverride) + assert ( + str(e.value) + == "`Arg.propagate` requires a non-positional named option (`short` or `long`)." + ) + + +def test_help_contains_propagated_arg(capsys): + @dataclass + class Command: + foo: Annotated[int, cappa.Arg(long=True, propagate=True, help="Everywhere")] = 1 + bar: Annotated[int, cappa.Arg(long=True, help="Nowhere")] = 1 + sub: cappa.Subcommands[One | Two | None] = None + + with pytest.raises(cappa.Exit): + parse(Command, "--help") + output = CapsysOutput.from_capsys(capsys) + assert "foo" in output.stdout + assert "Everywhere" in output.stdout + assert "Nowhere" in output.stdout + + with pytest.raises(cappa.Exit): + parse(Command, "one", "--help") + output = CapsysOutput.from_capsys(capsys) + assert "foo" in output.stdout + assert "Everywhere" in output.stdout + assert "Nowhere" not in output.stdout + + with pytest.raises(cappa.Exit): + parse(Command, "two", "--help") + output = CapsysOutput.from_capsys(capsys) + assert "foo" in output.stdout + assert "Everywhere" in output.stdout + assert "Nowhere" not in output.stdout