Skip to content

Commit b21f500

Browse files
Add ParamSpec node
1 parent 7e43d17 commit b21f500

File tree

6 files changed

+85
-3
lines changed

6 files changed

+85
-3
lines changed

astroid/nodes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
NamedExpr,
7272
NodeNG,
7373
Nonlocal,
74+
ParamSpec,
7475
Pass,
7576
Pattern,
7677
Raise,
@@ -182,6 +183,7 @@
182183
NamedExpr,
183184
NodeNG,
184185
Nonlocal,
186+
ParamSpec,
185187
Pass,
186188
Pattern,
187189
Raise,
@@ -275,6 +277,7 @@
275277
"NamedExpr",
276278
"NodeNG",
277279
"Nonlocal",
280+
"ParamSpec",
278281
"Pass",
279282
"Position",
280283
"Raise",

astroid/nodes/as_string.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,10 @@ def visit_nonlocal(self, node) -> str:
433433
"""return an astroid.Nonlocal node as string"""
434434
return f"nonlocal {', '.join(node.names)}"
435435

436+
def visit_paramspec(self, node: nodes.ParamSpec) -> str:
437+
"""return an astroid.ParamSpec node as string"""
438+
return node.name
439+
436440
def visit_pass(self, node) -> str:
437441
"""return an astroid.Pass node as string"""
438442
return "pass"

astroid/nodes/node_classes.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2696,6 +2696,52 @@ def _infer_name(self, frame, name):
26962696
return name
26972697

26982698

2699+
class ParamSpec(_base_nodes.AssignTypeNode):
2700+
"""Class representing a :class:`ast.ParamSpec` node.
2701+
2702+
>>> import astroid
2703+
>>> node = astroid.extract_node('type Alias[**P] = Callable[P, int]')
2704+
>>> node.type_params[0]
2705+
<ParamSpec l.1 at 0x7f23b2e4e198>
2706+
"""
2707+
2708+
def __init__(
2709+
self,
2710+
lineno: int | None = None,
2711+
col_offset: int | None = None,
2712+
parent: NodeNG | None = None,
2713+
*,
2714+
end_lineno: int | None = None,
2715+
end_col_offset: int | None = None,
2716+
) -> None:
2717+
self.name: str
2718+
super().__init__(
2719+
lineno=lineno,
2720+
col_offset=col_offset,
2721+
end_lineno=end_lineno,
2722+
end_col_offset=end_col_offset,
2723+
parent=parent,
2724+
)
2725+
2726+
def postinit(self, name: str) -> None:
2727+
self.name = name
2728+
2729+
assigned_stmts: ClassVar[
2730+
Callable[
2731+
[
2732+
ParamSpec,
2733+
AssignName,
2734+
InferenceContext | None,
2735+
None,
2736+
],
2737+
Generator[NodeNG, None, None],
2738+
]
2739+
]
2740+
"""Returns the assigned statement (non inferred) according to the assignment type.
2741+
See astroid/protocols.py for actual implementation.
2742+
"""
2743+
2744+
26992745
class Pass(_base_nodes.NoChildrenNode, _base_nodes.Statement):
27002746
"""Class representing an :class:`ast.Pass` node.
27012747
@@ -3330,7 +3376,7 @@ def __init__(
33303376
end_lineno: int | None = None,
33313377
end_col_offset: int | None = None,
33323378
) -> None:
3333-
self.type_params: list[TypeVar]
3379+
self.type_params: list[TypeVar, ParamSpec]
33343380
self.value: NodeNG
33353381
super().__init__(
33363382
lineno=lineno,
@@ -3343,7 +3389,7 @@ def __init__(
33433389
def postinit(
33443390
self,
33453391
*,
3346-
type_params: list[TypeVar],
3392+
type_params: list[TypeVar, ParamSpec],
33473393
value: NodeNG,
33483394
) -> None:
33493395
self.type_params = type_params

astroid/rebuilder.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,12 @@ def visit(self, node: ast.Nonlocal, parent: NodeNG) -> nodes.Nonlocal:
384384
def visit(self, node: ast.Constant, parent: NodeNG) -> nodes.Const:
385385
...
386386

387+
if sys.version_info >= (3, 12):
388+
389+
@overload
390+
def visit(self, node: ast.ParamSpec, parent: NodeNG) -> nodes.ParamSpec:
391+
...
392+
387393
@overload
388394
def visit(self, node: ast.Pass, parent: NodeNG) -> nodes.Pass:
389395
...
@@ -1493,6 +1499,18 @@ def visit_constant(self, node: ast.Constant, parent: NodeNG) -> nodes.Const:
14931499
parent=parent,
14941500
)
14951501

1502+
def visit_paramspec(self, node: ast.ParamSpec, parent: NodeNG) -> nodes.ParamSpec:
1503+
"""Visit a ParamSpec node by returning a fresh instance of it."""
1504+
newnode = nodes.ParamSpec(
1505+
lineno=node.lineno,
1506+
col_offset=node.col_offset,
1507+
end_lineno=node.end_lineno,
1508+
end_col_offset=node.end_col_offset,
1509+
parent=parent,
1510+
)
1511+
newnode.postinit(node.name)
1512+
return newnode
1513+
14961514
def visit_pass(self, node: ast.Pass, parent: NodeNG) -> nodes.Pass:
14971515
"""Visit a Pass node by returning a fresh instance of it."""
14981516
return nodes.Pass(

doc/api/astroid.nodes.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ Nodes
6767
astroid.nodes.Module
6868
astroid.nodes.Name
6969
astroid.nodes.Nonlocal
70+
astroid.nodes.ParamSpec
7071
astroid.nodes.Pass
7172
astroid.nodes.Raise
7273
astroid.nodes.Return
@@ -204,6 +205,8 @@ Nodes
204205

205206
.. autoclass:: astroid.nodes.Nonlocal
206207

208+
.. autoclass:: astroid.nodes.ParamSpec
209+
207210
.. autoclass:: astroid.nodes.Pass
208211

209212
.. autoclass:: astroid.nodes.Raise

tests/test_type_params.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from astroid import extract_node
88
from astroid.const import PY312_PLUS
9-
from astroid.nodes import Subscript, TypeAlias, TypeVar
9+
from astroid.nodes import ParamSpec, Subscript, TypeAlias, TypeVar
1010

1111

1212
@pytest.mark.skipif(not PY312_PLUS, reason="Requires Python 3.12 or higher")
@@ -23,6 +23,14 @@ def test_type_alias() -> None:
2323
assert all(elt.name == "float" for elt in node.value.slice.elts)
2424

2525

26+
@pytest.mark.skipif(not PY312_PLUS, reason="Requires Python 3.12 or higher")
27+
def test_type_param_spec() -> None:
28+
node = extract_node("type Alias[**P] = Callable[P, int]")
29+
params = node.type_params[0]
30+
assert isinstance(params, ParamSpec)
31+
assert params.name == "P"
32+
33+
2634
@pytest.mark.skipif(not PY312_PLUS, reason="Requires Python 3.12 or higher")
2735
def test_type_param() -> None:
2836
func_node = extract_node("def func[T]() -> T: ...")

0 commit comments

Comments
 (0)