|
| 1 | +""" |
| 2 | +Check that pandas/core/generic.py doesn't use bool as a type annotation. |
| 3 | +
|
| 4 | +There is already the method `bool`, so the alias `bool_t` should be used instead. |
| 5 | +
|
| 6 | +This is meant to be run as a pre-commit hook - to run it manually, you can do: |
| 7 | +
|
| 8 | + pre-commit run no-bool-in-core-generic --all-files |
| 9 | +
|
| 10 | +The function `visit` is adapted from a function by the same name in pyupgrade: |
| 11 | +https://github.com/asottile/pyupgrade/blob/5495a248f2165941c5d3b82ac3226ba7ad1fa59d/pyupgrade/_data.py#L70-L113 |
| 12 | +""" |
| 13 | + |
| 14 | +import argparse |
| 15 | +import ast |
| 16 | +import collections |
| 17 | +from typing import ( |
| 18 | + Dict, |
| 19 | + List, |
| 20 | + Optional, |
| 21 | + Sequence, |
| 22 | + Tuple, |
| 23 | +) |
| 24 | + |
| 25 | + |
| 26 | +def visit( |
| 27 | + tree: ast.Module, |
| 28 | +) -> Dict[int, List[int]]: |
| 29 | + in_annotation = False |
| 30 | + nodes: List[Tuple[bool, ast.AST]] = [(in_annotation, tree)] |
| 31 | + to_replace = collections.defaultdict(list) |
| 32 | + |
| 33 | + while nodes: |
| 34 | + in_annotation, node = nodes.pop() |
| 35 | + |
| 36 | + if isinstance(node, ast.Name) and in_annotation and node.id == "bool": |
| 37 | + to_replace[node.lineno].append(node.col_offset) |
| 38 | + |
| 39 | + for name in reversed(node._fields): |
| 40 | + value = getattr(node, name) |
| 41 | + if name in {"annotation", "returns"}: |
| 42 | + next_in_annotation = True |
| 43 | + else: |
| 44 | + next_in_annotation = in_annotation |
| 45 | + if isinstance(value, ast.AST): |
| 46 | + nodes.append((next_in_annotation, value)) |
| 47 | + elif isinstance(value, list): |
| 48 | + for value in reversed(value): |
| 49 | + if isinstance(value, ast.AST): |
| 50 | + nodes.append((next_in_annotation, value)) |
| 51 | + |
| 52 | + return to_replace |
| 53 | + |
| 54 | + |
| 55 | +def replace_bool_with_bool_t(to_replace, content: str) -> str: |
| 56 | + new_lines = [] |
| 57 | + |
| 58 | + for n, line in enumerate(content.splitlines(), start=1): |
| 59 | + if n in to_replace: |
| 60 | + for col_offset in reversed(to_replace[n]): |
| 61 | + line = line[:col_offset] + "bool_t" + line[col_offset + 4 :] |
| 62 | + new_lines.append(line) |
| 63 | + return "\n".join(new_lines) |
| 64 | + |
| 65 | + |
| 66 | +def check_for_bool_in_generic(content: str) -> Tuple[bool, str]: |
| 67 | + tree = ast.parse(content) |
| 68 | + to_replace = visit(tree) |
| 69 | + if not to_replace: |
| 70 | + return False, content |
| 71 | + return True, replace_bool_with_bool_t(to_replace, content) |
| 72 | + |
| 73 | + |
| 74 | +def main(argv: Optional[Sequence[str]] = None) -> None: |
| 75 | + parser = argparse.ArgumentParser() |
| 76 | + parser.add_argument("paths", nargs="*") |
| 77 | + args = parser.parse_args(argv) |
| 78 | + |
| 79 | + for path in args.paths: |
| 80 | + with open(path, encoding="utf-8") as fd: |
| 81 | + content = fd.read() |
| 82 | + replace, new_content = check_for_bool_in_generic(content) |
| 83 | + if replace: |
| 84 | + with open(path, "w", encoding="utf-8") as fd: |
| 85 | + fd.write(new_content) |
| 86 | + |
| 87 | + |
| 88 | +if __name__ == "__main__": |
| 89 | + main() |
0 commit comments