diff --git a/torchx/specs/builders.py b/torchx/specs/builders.py index 88e4b85f3..fa6a21455 100644 --- a/torchx/specs/builders.py +++ b/torchx/specs/builders.py @@ -7,9 +7,11 @@ # pyre-strict import argparse +import dataclasses import inspect import os from argparse import Namespace +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Mapping, Optional, Union from torchx.specs.api import BindMount, MountType, VolumeMount @@ -24,12 +26,74 @@ def _create_args_parser( cmpnt_defaults: Optional[Dict[str, str]] = None, config: Optional[Dict[str, Any]] = None, ) -> argparse.ArgumentParser: - parameters = inspect.signature(cmpnt_fn).parameters + parameters = _get_params_from_component_signature(cmpnt_fn).parameters return _create_args_parser_from_parameters( cmpnt_fn, parameters, cmpnt_defaults, config ) +@dataclass +class SignatureInfo: + parameters: Mapping[str, inspect.Parameter] + dataclass_type: type[object] | None + + +def _get_params_from_component_signature( + cmpnt_fn: Callable[..., AppDef] +) -> SignatureInfo: + parameters = inspect.signature(cmpnt_fn).parameters + dataclass_type = _maybe_get_dataclass_type(parameters) + if dataclass_type is not None: + parameters = _flatten_dataclass_params(parameters) + return SignatureInfo(parameters, dataclass_type) + + +def _maybe_get_dataclass_type( + parameters: Mapping[str, inspect.Parameter] +) -> type[object] | None: + if len(parameters) not in (1, 2): + # only support a single dataclass or a single dataclass followed by a vararg + return None + params = list(parameters.values()) + first_param_type = params[0].annotation + is_first_param_dataclass = dataclasses.is_dataclass( + first_param_type + ) and isinstance(first_param_type, type) + if not is_first_param_dataclass: + return None + if len(params) == 1: + return first_param_type + if len(params) == 2 and params[1].kind == inspect.Parameter.VAR_POSITIONAL: + return first_param_type + return None + + +def _flatten_dataclass_params( + parameters: Mapping[str, inspect.Parameter] +) -> Mapping[str, inspect.Parameter]: + result = {} + + for param_name, param in parameters.items(): + param_type = param.annotation + if not dataclasses.is_dataclass(param_type): + result[param_name] = param + continue + else: + result.update( + { + f.name: inspect.Parameter( + f.name, + inspect._ParameterKind.KEYWORD_ONLY, + annotation=f.type, + default=f.default, + ) + for f in dataclasses.fields(param_type) + } + ) + + return result + + def _create_args_parser_from_parameters( cmpnt_fn: Callable[..., Any], # pyre-ignore[2] parameters: Mapping[str, inspect.Parameter], @@ -69,7 +133,7 @@ def __call__( ) for param_name, parameter in parameters.items(): - param_desc = args_desc[parameter.name] + param_desc = args_desc.get(parameter.name) args: Dict[str, Any] = { "help": param_desc, "type": get_argparse_param_type(parameter), @@ -147,10 +211,9 @@ def materialize_appdef( config: Optional[Dict[str, Any]] = None, ) -> AppDef: """ - Creates an application by running user defined ``app_fn``. + Creates an application by running a user-defined component function ``cmpnt_fn``. - ``app_fn`` has the following restrictions: - * Name must be ``app_fn`` + ``cmpnt_fn`` has the following restrictions: * All arguments should be annotated * Supported argument types: - primitive: int, str, float @@ -158,7 +221,17 @@ def materialize_appdef( - List[primitive] - Optional[Dict[primitive, primitive]] - Optional[List[primitive]] - * ``app_fn`` can define a vararg (*arg) at the end + + The arguments can also be passed as a single dataclass, e.g. + + @dataclass + class Args: + arg1: str + arg2: Dict[str, int] + + def cmpnt_fn(args: Args) -> AppDef: ... + + * ``cmpnt_fn`` can define a vararg (*arg) at the end (this also works if the first argument is a dataclass) * There should be a docstring for the function that defines All arguments in a google-style format * There can be default values for the function arguments. @@ -180,8 +253,9 @@ def materialize_appdef( parsed_args = parse_args(cmpnt_fn, cmpnt_args, cmpnt_defaults, config) - parameters = inspect.signature(cmpnt_fn).parameters - for param_name, parameter in parameters.items(): + signature_info = _get_params_from_component_signature(cmpnt_fn) + + for param_name, parameter in signature_info.parameters.items(): arg_value = getattr(parsed_args, param_name) parameter_type = parameter.annotation parameter_type = decode_optional(parameter_type) @@ -197,6 +271,9 @@ def materialize_appdef( if len(var_arg) > 0 and var_arg[0] == "--": var_arg = var_arg[1:] + if signature_info.dataclass_type is not None: + function_args = [signature_info.dataclass_type(**kwargs)] + kwargs = {} appdef = cmpnt_fn(*function_args, *var_arg, **kwargs) if not isinstance(appdef, AppDef): raise TypeError( diff --git a/torchx/specs/test/builders_test.py b/torchx/specs/test/builders_test.py index a733f5502..da178577b 100644 --- a/torchx/specs/test/builders_test.py +++ b/torchx/specs/test/builders_test.py @@ -9,7 +9,7 @@ import argparse import sys import unittest -from dataclasses import asdict +from dataclasses import asdict, dataclass from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from unittest.mock import patch @@ -130,6 +130,34 @@ def example_test_complex_fn( return AppDef(app_name, roles) +@dataclass +class ComplexArgs: + app_name: str + containers: List[str] + roles_scripts: Dict[str, str] + num_cpus: Optional[List[int]] = None + num_gpus: Optional[Dict[str, int]] = None + nnodes: int = 4 + first_arg: Optional[str] = None + nested_arg: Optional[Dict[str, List[str]]] = None + + +def example_test_complex_fn_dataclass_arg( + args: ComplexArgs, *roles_args: str +) -> AppDef: + return example_test_complex_fn( + args.app_name, + args.containers, + args.roles_scripts, + args.num_cpus, + args.num_gpus, + args.nnodes, + args.first_arg, + args.nested_arg, + *roles_args, + ) + + _TEST_VAR_ARGS: Optional[Tuple[object, ...]] = None @@ -292,6 +320,12 @@ def test_load_from_fn_complex_all_args(self) -> None: actual_app = materialize_appdef(example_test_complex_fn, app_args) self.assert_apps(expected_app, actual_app) + def test_load_from_fn_complex_all_args_dataclass(self) -> None: + expected_app = self._get_expected_app_with_all_args() + app_args = self._get_app_args() + actual_app = materialize_appdef(example_test_complex_fn_dataclass_arg, app_args) + self.assert_apps(expected_app, actual_app) + def test_required_args(self) -> None: with patch.object(sys, "exit") as exit_mock: try: