Skip to content

support dataclasses for component args #1074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 85 additions & 8 deletions torchx/specs/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
# pyre-strict

import argparse
import dataclasses
import inspect
import os
from argparse import Namespace
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Mapping, Optional, Union

from torchx.specs.api import BindMount, MountType, VolumeMount
Expand All @@ -24,12 +26,74 @@ def _create_args_parser(
cmpnt_defaults: Optional[Dict[str, str]] = None,
config: Optional[Dict[str, Any]] = None,
) -> argparse.ArgumentParser:
parameters = inspect.signature(cmpnt_fn).parameters
parameters = _get_params_from_component_signature(cmpnt_fn).parameters
return _create_args_parser_from_parameters(
cmpnt_fn, parameters, cmpnt_defaults, config
)


@dataclass
class SignatureInfo:
parameters: Mapping[str, inspect.Parameter]
dataclass_type: type[object] | None


def _get_params_from_component_signature(
    cmpnt_fn: Callable[..., AppDef]
) -> SignatureInfo:
    """Inspect ``cmpnt_fn`` and describe the arguments it accepts.

    When the signature qualifies as a "dataclass component" (as decided by
    ``_maybe_get_dataclass_type``), the returned parameters are the flattened
    dataclass fields; otherwise they are the function's own parameters.

    Args:
        cmpnt_fn: the component function to inspect.

    Returns:
        A ``SignatureInfo`` holding the parameters and the detected dataclass
        type (``None`` if the component does not take a dataclass).
    """
    raw_params = inspect.signature(cmpnt_fn).parameters
    detected_type = _maybe_get_dataclass_type(raw_params)
    if detected_type is None:
        return SignatureInfo(raw_params, None)
    return SignatureInfo(_flatten_dataclass_params(raw_params), detected_type)


def _maybe_get_dataclass_type(
parameters: Mapping[str, inspect.Parameter]
) -> type[object] | None:
if len(parameters) not in (1, 2):
# only support a single dataclass or a single dataclass followed by a vararg
return None
params = list(parameters.values())
first_param_type = params[0].annotation
is_first_param_dataclass = dataclasses.is_dataclass(
first_param_type
) and isinstance(first_param_type, type)
if not is_first_param_dataclass:
return None
if len(params) == 1:
return first_param_type
if len(params) == 2 and params[1].kind == inspect.Parameter.VAR_POSITIONAL:
return first_param_type
return None


def _flatten_dataclass_params(
    parameters: Mapping[str, inspect.Parameter]
) -> Mapping[str, inspect.Parameter]:
    """Expand dataclass-annotated parameters into one parameter per field.

    Parameters whose annotation is not a dataclass, as well as variadic
    parameters (``*args`` / ``**kwargs``), are passed through untouched. Each
    dataclass-annotated parameter is replaced by keyword-only parameters, one
    per field that the dataclass constructor accepts.

    Args:
        parameters: the parameters of a function signature, keyed by name.

    Returns:
        A new mapping with dataclass parameters replaced by their fields, in
        the original order.
    """
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    result: Dict[str, inspect.Parameter] = {}

    for param_name, param in parameters.items():
        param_type = param.annotation
        # ``*args: SomeDataclass`` describes the type of each vararg element;
        # it must remain a vararg rather than be replaced by the fields.
        if param.kind in variadic_kinds or not dataclasses.is_dataclass(param_type):
            result[param_name] = param
            continue

        for f in dataclasses.fields(param_type):
            if not f.init:
                # Not accepted by the generated ``__init__``, so it cannot be
                # supplied as an argument when the dataclass is constructed.
                continue

            # Dataclasses mark "no default" with ``dataclasses.MISSING`` while
            # ``inspect`` uses ``Parameter.empty``; translate so that fields
            # without a default are still seen as required.
            if f.default is not dataclasses.MISSING:
                default = f.default
            elif f.default_factory is not dataclasses.MISSING:
                default = f.default_factory()
            else:
                default = inspect.Parameter.empty

            result[f.name] = inspect.Parameter(
                f.name,
                inspect.Parameter.KEYWORD_ONLY,
                annotation=f.type,
                default=default,
            )

    return result


def _create_args_parser_from_parameters(
cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
parameters: Mapping[str, inspect.Parameter],
Expand Down Expand Up @@ -69,7 +133,7 @@ def __call__(
)

for param_name, parameter in parameters.items():
param_desc = args_desc[parameter.name]
param_desc = args_desc.get(parameter.name)
args: Dict[str, Any] = {
"help": param_desc,
"type": get_argparse_param_type(parameter),
Expand Down Expand Up @@ -147,18 +211,27 @@ def materialize_appdef(
config: Optional[Dict[str, Any]] = None,
) -> AppDef:
"""
Creates an application by running user defined ``app_fn``.
Creates an application by running a user-defined component function ``cmpnt_fn``.

``app_fn`` has the following restrictions:
* Name must be ``app_fn``
``cmpnt_fn`` has the following restrictions:
* All arguments should be annotated
* Supported argument types:
- primitive: int, str, float
- Dict[primitive, primitive]
- List[primitive]
- Optional[Dict[primitive, primitive]]
- Optional[List[primitive]]
* ``app_fn`` can define a vararg (*arg) at the end

The arguments can also be passed as a single dataclass, e.g.

@dataclass
class Args:
arg1: str
arg2: Dict[str, int]

def cmpnt_fn(args: Args) -> AppDef: ...

* ``cmpnt_fn`` can define a vararg (*arg) at the end (this also works if the first argument is a dataclass)
* There should be a docstring for the function that defines
All arguments in a google-style format
* There can be default values for the function arguments.
Expand All @@ -180,8 +253,9 @@ def materialize_appdef(

parsed_args = parse_args(cmpnt_fn, cmpnt_args, cmpnt_defaults, config)

parameters = inspect.signature(cmpnt_fn).parameters
for param_name, parameter in parameters.items():
signature_info = _get_params_from_component_signature(cmpnt_fn)

for param_name, parameter in signature_info.parameters.items():
arg_value = getattr(parsed_args, param_name)
parameter_type = parameter.annotation
parameter_type = decode_optional(parameter_type)
Expand All @@ -197,6 +271,9 @@ def materialize_appdef(
if len(var_arg) > 0 and var_arg[0] == "--":
var_arg = var_arg[1:]

if signature_info.dataclass_type is not None:
function_args = [signature_info.dataclass_type(**kwargs)]
kwargs = {}
appdef = cmpnt_fn(*function_args, *var_arg, **kwargs)
if not isinstance(appdef, AppDef):
raise TypeError(
Expand Down
36 changes: 35 additions & 1 deletion torchx/specs/test/builders_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import argparse
import sys
import unittest
from dataclasses import asdict
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import patch
Expand Down Expand Up @@ -130,6 +130,34 @@ def example_test_complex_fn(
return AppDef(app_name, roles)


@dataclass
class ComplexArgs:
    """Dataclass bundle of component arguments used by the dataclass-arg test.

    ``example_test_complex_fn_dataclass_arg`` forwards these fields, in
    declaration order, to ``example_test_complex_fn``.
    """

    # Fields without defaults.
    app_name: str
    containers: List[str]
    roles_scripts: Dict[str, str]
    # Fields with defaults.
    num_cpus: Optional[List[int]] = None
    num_gpus: Optional[Dict[str, int]] = None
    nnodes: int = 4
    first_arg: Optional[str] = None
    nested_arg: Optional[Dict[str, List[str]]] = None


def example_test_complex_fn_dataclass_arg(
    args: ComplexArgs, *roles_args: str
) -> AppDef:
    """Dataclass-taking variant of ``example_test_complex_fn``.

    Unpacks the fields of ``args`` and delegates to ``example_test_complex_fn``
    positionally, followed by the variadic ``roles_args``.
    """
    # Order matters: the fields are forwarded positionally.
    positional = (
        args.app_name,
        args.containers,
        args.roles_scripts,
        args.num_cpus,
        args.num_gpus,
        args.nnodes,
        args.first_arg,
        args.nested_arg,
    )
    return example_test_complex_fn(*positional, *roles_args)


_TEST_VAR_ARGS: Optional[Tuple[object, ...]] = None


Expand Down Expand Up @@ -292,6 +320,12 @@ def test_load_from_fn_complex_all_args(self) -> None:
actual_app = materialize_appdef(example_test_complex_fn, app_args)
self.assert_apps(expected_app, actual_app)

def test_load_from_fn_complex_all_args_dataclass(self) -> None:
    """Materialize the dataclass-arg component and compare with the expected app."""
    want = self._get_expected_app_with_all_args()
    cli_args = self._get_app_args()
    got = materialize_appdef(example_test_complex_fn_dataclass_arg, cli_args)
    self.assert_apps(want, got)

def test_required_args(self) -> None:
with patch.object(sys, "exit") as exit_mock:
try:
Expand Down
Loading