Skip to content

Commit 2196004

Browse files
authored
TYP use bool instead of bool_t in pandas/core/generic.py (pandas-dev#40175)
1 parent b2f9e1f commit 2196004

File tree

5 files changed

+151
-15
lines changed

5 files changed

+151
-15
lines changed

.pre-commit-config.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,8 @@ repos:
206206
files: ^pandas/core/
207207
exclude: ^pandas/core/api\.py$
208208
types: [python]
209+
- id: no-bool-in-core-generic
210+
name: Use bool_t instead of bool in pandas/core/generic.py
211+
entry: python scripts/no_bool_in_generic.py
212+
language: python
213+
files: ^pandas/core/generic\.py$

LICENSES/PYUPGRADE_LICENSE

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
Copyright (c) 2017 Anthony Sottile
2+
3+
Permission is hereby granted, free of charge, to any person obtaining a copy
4+
of this software and associated documentation files (the "Software"), to deal
5+
in the Software without restriction, including without limitation the rights
6+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7+
copies of the Software, and to permit persons to whom the Software is
8+
furnished to do so, subject to the following conditions:
9+
10+
The above copyright notice and this permission notice shall be included in
11+
all copies or substantial portions of the Software.
12+
13+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19+
THE SOFTWARE.

pandas/core/generic.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ class NDFrame(PandasObject, SelectionMixin, indexing.IndexingMixin):
232232
def __init__(
233233
self,
234234
data: Manager,
235-
copy: bool = False,
235+
copy: bool_t = False,
236236
attrs: Mapping[Hashable, Any] | None = None,
237237
):
238238
# copy kwarg is retained for mypy compat, is not used
@@ -249,7 +249,7 @@ def __init__(
249249

250250
@classmethod
251251
def _init_mgr(
252-
cls, mgr, axes, dtype: Dtype | None = None, copy: bool = False
252+
cls, mgr, axes, dtype: Dtype | None = None, copy: bool_t = False
253253
) -> Manager:
254254
""" passed a manager and a axes dict """
255255
for a, axe in axes.items():
@@ -377,8 +377,8 @@ def flags(self) -> Flags:
377377
def set_flags(
378378
self: FrameOrSeries,
379379
*,
380-
copy: bool = False,
381-
allows_duplicate_labels: bool | None = None,
380+
copy: bool_t = False,
381+
allows_duplicate_labels: bool_t | None = None,
382382
) -> FrameOrSeries:
383383
"""
384384
Return a new object with updated flags.
@@ -467,7 +467,7 @@ def _data(self):
467467
_stat_axis_name = "index"
468468
_AXIS_ORDERS: list[str]
469469
_AXIS_TO_AXIS_NUMBER: dict[Axis, int] = {0: 0, "index": 0, "rows": 0}
470-
_AXIS_REVERSED: bool
470+
_AXIS_REVERSED: bool_t
471471
_info_axis_number: int
472472
_info_axis_name: str
473473
_AXIS_LEN: int
@@ -494,7 +494,7 @@ def _construct_axes_dict(self, axes=None, **kwargs):
494494
@final
495495
@classmethod
496496
def _construct_axes_from_arguments(
497-
cls, args, kwargs, require_all: bool = False, sentinel=None
497+
cls, args, kwargs, require_all: bool_t = False, sentinel=None
498498
):
499499
"""
500500
Construct and returns axes if supplied in args/kwargs.
@@ -714,11 +714,11 @@ def set_axis(self: FrameOrSeries, labels, *, inplace: Literal[True]) -> None:
714714

715715
@overload
716716
def set_axis(
717-
self: FrameOrSeries, labels, axis: Axis = ..., inplace: bool = ...
717+
self: FrameOrSeries, labels, axis: Axis = ..., inplace: bool_t = ...
718718
) -> FrameOrSeries | None:
719719
...
720720

721-
def set_axis(self, labels, axis: Axis = 0, inplace: bool = False):
721+
def set_axis(self, labels, axis: Axis = 0, inplace: bool_t = False):
722722
"""
723723
Assign desired index to given axis.
724724
@@ -749,7 +749,7 @@ def set_axis(self, labels, axis: Axis = 0, inplace: bool = False):
749749
return self._set_axis_nocheck(labels, axis, inplace)
750750

751751
@final
752-
def _set_axis_nocheck(self, labels, axis: Axis, inplace: bool):
752+
def _set_axis_nocheck(self, labels, axis: Axis, inplace: bool_t):
753753
# NDFrame.rename with inplace=False calls set_axis(inplace=True) on a copy.
754754
if inplace:
755755
setattr(self, self._get_axis_name(axis), labels)
@@ -993,8 +993,8 @@ def rename(
993993
index: Renamer | None = None,
994994
columns: Renamer | None = None,
995995
axis: Axis | None = None,
996-
copy: bool = True,
997-
inplace: bool = False,
996+
copy: bool_t = True,
997+
inplace: bool_t = False,
998998
level: Level | None = None,
999999
errors: str = "ignore",
10001000
) -> FrameOrSeries | None:
@@ -1400,13 +1400,13 @@ def _set_axis_name(self, name, axis=0, inplace=False):
14001400
# Comparison Methods
14011401

14021402
@final
1403-
def _indexed_same(self, other) -> bool:
1403+
def _indexed_same(self, other) -> bool_t:
14041404
return all(
14051405
self._get_axis(a).equals(other._get_axis(a)) for a in self._AXIS_ORDERS
14061406
)
14071407

14081408
@final
1409-
def equals(self, other: object) -> bool:
1409+
def equals(self, other: object) -> bool_t:
14101410
"""
14111411
Test whether two objects contain the same elements.
14121412
@@ -4992,15 +4992,15 @@ def filter(
49924992
return self.reindex(**{name: [r for r in items if r in labels]})
49934993
elif like:
49944994

4995-
def f(x) -> bool:
4995+
def f(x) -> bool_t:
49964996
assert like is not None # needed for mypy
49974997
return like in ensure_str(x)
49984998

49994999
values = labels.map(f)
50005000
return self.loc(axis=axis)[values]
50015001
elif regex:
50025002

5003-
def f(x) -> bool:
5003+
def f(x) -> bool_t:
50045004
return matcher.search(ensure_str(x)) is not None
50055005

50065006
matcher = re.compile(regex)

scripts/no_bool_in_generic.py

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""
2+
Check that pandas/core/generic.py doesn't use bool as a type annotation.
3+
4+
There is already the method `bool`, so the alias `bool_t` should be used instead.
5+
6+
This is meant to be run as a pre-commit hook - to run it manually, you can do:
7+
8+
pre-commit run no-bool-in-core-generic --all-files
9+
10+
The function `visit` is adapted from a function by the same name in pyupgrade:
11+
https://github.com/asottile/pyupgrade/blob/5495a248f2165941c5d3b82ac3226ba7ad1fa59d/pyupgrade/_data.py#L70-L113
12+
"""
13+
14+
import argparse
15+
import ast
16+
import collections
17+
from typing import (
18+
Dict,
19+
List,
20+
Optional,
21+
Sequence,
22+
Tuple,
23+
)
24+
25+
26+
def visit(tree: ast.Module) -> Dict[int, List[int]]:
27+
"Step through tree, recording when nodes are in annotations."
28+
in_annotation = False
29+
nodes: List[Tuple[bool, ast.AST]] = [(in_annotation, tree)]
30+
to_replace = collections.defaultdict(list)
31+
32+
while nodes:
33+
in_annotation, node = nodes.pop()
34+
35+
if isinstance(node, ast.Name) and in_annotation and node.id == "bool":
36+
to_replace[node.lineno].append(node.col_offset)
37+
38+
for name in reversed(node._fields):
39+
value = getattr(node, name)
40+
if name in {"annotation", "returns"}:
41+
next_in_annotation = True
42+
else:
43+
next_in_annotation = in_annotation
44+
if isinstance(value, ast.AST):
45+
nodes.append((next_in_annotation, value))
46+
elif isinstance(value, list):
47+
for value in reversed(value):
48+
if isinstance(value, ast.AST):
49+
nodes.append((next_in_annotation, value))
50+
51+
return to_replace
52+
53+
54+
def replace_bool_with_bool_t(to_replace, content: str) -> str:
55+
new_lines = []
56+
57+
for n, line in enumerate(content.splitlines(), start=1):
58+
if n in to_replace:
59+
for col_offset in reversed(to_replace[n]):
60+
line = line[:col_offset] + "bool_t" + line[col_offset + 4 :]
61+
new_lines.append(line)
62+
return "\n".join(new_lines)
63+
64+
65+
def check_for_bool_in_generic(content: str) -> Tuple[bool, str]:
66+
tree = ast.parse(content)
67+
to_replace = visit(tree)
68+
69+
if not to_replace:
70+
mutated = False
71+
return mutated, content
72+
73+
mutated = True
74+
return mutated, replace_bool_with_bool_t(to_replace, content)
75+
76+
77+
def main(argv: Optional[Sequence[str]] = None) -> None:
78+
parser = argparse.ArgumentParser()
79+
parser.add_argument("paths", nargs="*")
80+
args = parser.parse_args(argv)
81+
82+
for path in args.paths:
83+
with open(path, encoding="utf-8") as fd:
84+
content = fd.read()
85+
mutated, new_content = check_for_bool_in_generic(content)
86+
if mutated:
87+
with open(path, "w", encoding="utf-8") as fd:
88+
fd.write(new_content)
89+
90+
91+
if __name__ == "__main__":
92+
main()
+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from scripts.no_bool_in_generic import check_for_bool_in_generic
2+
3+
BAD_FILE = "def foo(a: bool) -> bool:\n return bool(0)"
4+
GOOD_FILE = "def foo(a: bool_t) -> bool_t:\n return bool(0)"
5+
6+
7+
def test_bad_file_with_replace():
8+
content = BAD_FILE
9+
mutated, result = check_for_bool_in_generic(content)
10+
expected = GOOD_FILE
11+
assert result == expected
12+
assert mutated
13+
14+
15+
def test_good_file_with_replace():
16+
content = GOOD_FILE
17+
mutated, result = check_for_bool_in_generic(content)
18+
expected = content
19+
assert result == expected
20+
assert not mutated

0 commit comments

Comments
 (0)